Compare commits
7 Commits
301ad12241
...
v0.11.0-v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cea31d16fb | ||
|
|
01bafad6d0 | ||
|
|
070bfa4a73 | ||
|
|
fc48b79ae9 | ||
|
|
bd8c999335 | ||
|
|
9b1f25fbe3 | ||
|
|
42c7ef2f27 |
@@ -1,76 +1,76 @@
|
|||||||
name: e2e-test
|
# name: e2e-test
|
||||||
|
|
||||||
on:
|
# on:
|
||||||
workflow_call:
|
# workflow_call:
|
||||||
pull_request:
|
# pull_request:
|
||||||
branches: [main]
|
# branches: [main]
|
||||||
types: [opened, synchronize, reopened]
|
# types: [opened, synchronize, reopened]
|
||||||
push:
|
# push:
|
||||||
branches: [main]
|
# branches: [main]
|
||||||
|
|
||||||
concurrency:
|
# concurrency:
|
||||||
group: e2e-singlecard
|
# group: e2e-singlecard
|
||||||
cancel-in-progress: false
|
# cancel-in-progress: false
|
||||||
|
|
||||||
jobs:
|
# jobs:
|
||||||
e2e:
|
# e2e:
|
||||||
name: e2e-test-singlecard
|
# name: e2e-test-singlecard
|
||||||
runs-on:
|
# runs-on:
|
||||||
- self-hosted
|
# - self-hosted
|
||||||
- Linux
|
# - Linux
|
||||||
- X64
|
# - X64
|
||||||
|
|
||||||
steps:
|
# steps:
|
||||||
- name: Checkout PR code
|
# - name: Checkout PR code
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
with:
|
# with:
|
||||||
fetch-depth: 0
|
# fetch-depth: 0
|
||||||
|
|
||||||
- name: Verify PR workspace
|
# - name: Verify PR workspace
|
||||||
run: |
|
# run: |
|
||||||
echo "===== WORKSPACE ====="
|
# echo "===== WORKSPACE ====="
|
||||||
pwd
|
# pwd
|
||||||
ls -l
|
# ls -l
|
||||||
echo "===== GIT INFO ====="
|
# echo "===== GIT INFO ====="
|
||||||
git rev-parse HEAD
|
# git rev-parse HEAD
|
||||||
git log -1 --oneline
|
# git log -1 --oneline
|
||||||
git status --porcelain
|
# git status --porcelain
|
||||||
|
|
||||||
- name: Start docker
|
# - name: Start docker
|
||||||
run: |
|
# run: |
|
||||||
bash ci/scripts/docker/start_docker.sh
|
# bash ci/scripts/docker/start_docker.sh
|
||||||
|
|
||||||
- name: Install enviroments
|
# - name: Install enviroments
|
||||||
env:
|
# env:
|
||||||
PROXY_URL: ${{ secrets.PROXY_URL }}
|
# PROXY_URL: ${{ secrets.PROXY_URL }}
|
||||||
NO_PROXY_LIST: ${{ secrets.NO_PROXY_LIST }}
|
# NO_PROXY_LIST: ${{ secrets.NO_PROXY_LIST }}
|
||||||
run: |
|
# run: |
|
||||||
bash ci/scripts/env/install_env.sh
|
# bash ci/scripts/env/install_env.sh
|
||||||
|
|
||||||
- name: Start vLLM server
|
# - name: Start vLLM server
|
||||||
run: |
|
# run: |
|
||||||
bash ci/scripts/server/start_vllm.sh
|
# bash ci/scripts/server/start_vllm.sh
|
||||||
|
|
||||||
- name: Wait for vLLM ready
|
# - name: Wait for vLLM ready
|
||||||
run: |
|
# run: |
|
||||||
bash ci/scripts/server/wait_vllm.sh
|
# bash ci/scripts/server/wait_vllm.sh
|
||||||
|
|
||||||
- name: API Test
|
# - name: API Test
|
||||||
run: |
|
# run: |
|
||||||
docker exec aiak-e2e-singlecard bash -lc '
|
# docker exec aiak-e2e-singlecard bash -lc '
|
||||||
curl http://localhost:8356/v1/chat/completions \
|
# curl http://localhost:8356/v1/chat/completions \
|
||||||
-H "Content-Type: application/json" \
|
# -H "Content-Type: application/json" \
|
||||||
-d @- << "EOF"
|
# -d @- << "EOF"
|
||||||
{
|
# {
|
||||||
"model": "Qwen3-8B",
|
# "model": "Qwen3-8B",
|
||||||
"messages": [
|
# "messages": [
|
||||||
{ "role": "user", "content": "Who are you?" }
|
# { "role": "user", "content": "Who are you?" }
|
||||||
],
|
# ],
|
||||||
"max_tokens": 200,
|
# "max_tokens": 200,
|
||||||
"temperature": 0
|
# "temperature": 0
|
||||||
}
|
# }
|
||||||
EOF
|
# EOF
|
||||||
'
|
# '
|
||||||
|
|
||||||
# - name: Accuracy testing
|
# - name: Accuracy testing
|
||||||
# run: |
|
# run: |
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ This document describes how to install vllm-kunlun manually.
|
|||||||
- vLLM (same version as vllm-kunlun)
|
- vLLM (same version as vllm-kunlun)
|
||||||
|
|
||||||
## Setup environment using container
|
## Setup environment using container
|
||||||
We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:base_v0.0.2` and `wjie520/vllm_kunlun:base_mimo_v0.0.2`(Only MIMO_V2 and GPT-OSS).You can pull it using the `docker pull` command.
|
We provide a clean, minimal base image for your use: `wjie520/vllm_kunlun:uv_base`. You can pull it using the `docker pull wjie520/vllm_kunlun:uv_base` command.
|
||||||
|
|
||||||
|
We also provide images with xpytorch and ops preinstalled: `wjie520/vllm_kunlun:base_v0.0.2` and `wjie520/vllm_kunlun:base_mimo_v0.0.2` (only for MIMO_V2 and GPT-OSS). You can pull them using the `docker pull` command.
|
||||||
### Container startup script
|
### Container startup script
|
||||||
|
|
||||||
:::::{tab-set}
|
:::::{tab-set}
|
||||||
@@ -19,9 +21,8 @@ We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:base_v0.
|
|||||||
|
|
||||||
::::{tab-item} start_docker.sh
|
::::{tab-item} start_docker.sh
|
||||||
:selected:
|
:selected:
|
||||||
:sync: pip
|
:sync: uv pip
|
||||||
```{code-block} bash
|
```{code-block} bash
|
||||||
:substitutions:
|
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
XPU_NUM=8
|
XPU_NUM=8
|
||||||
DOCKER_DEVICE_CONFIG=""
|
DOCKER_DEVICE_CONFIG=""
|
||||||
@@ -31,7 +32,7 @@ if [ $XPU_NUM -gt 0 ]; then
|
|||||||
done
|
done
|
||||||
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
|
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
|
||||||
fi
|
fi
|
||||||
export build_image="wjie520/vllm_kunlun:base_v0.0.2"
|
export build_image="wjie520/vllm_kunlun:uv_base"
|
||||||
# or export build_image="iregistry.baidu-int.com/xmlir/xmlir_ubuntu_2004_x86_64:v0.32"
|
# or export build_image="iregistry.baidu-int.com/xmlir/xmlir_ubuntu_2004_x86_64:v0.32"
|
||||||
|
|
||||||
docker run -itd ${DOCKER_DEVICE_CONFIG} \
|
docker run -itd ${DOCKER_DEVICE_CONFIG} \
|
||||||
@@ -63,8 +64,71 @@ uv pip install -r requirements.txt
|
|||||||
python setup.py build
|
python setup.py build
|
||||||
|
|
||||||
python setup.py install
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
|
### Replace eval_frame.py
|
||||||
|
Copy the eval_frame.py patch:
|
||||||
|
```
|
||||||
|
cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Choose to download customized xpytorch
|
||||||
|
|
||||||
|
### Install the KL3-customized build of PyTorch
|
||||||
|
```
|
||||||
|
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-02T05%3A01%3A27Z%2F-1%2Fhost%2Ff3cf499234f82303891aed2bcb0628918e379a21e841a3fac6bd94afef491ff7
|
||||||
|
# For a conda environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
|
||||||
|
# For a uv virtual environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
|
||||||
|
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
|
||||||
|
```
|
||||||
|
### Install the KL3-customized build of PyTorch (Only MIMO V2)
|
||||||
|
```
|
||||||
|
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run
|
||||||
|
# For a conda environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
|
||||||
|
# For a uv virtual environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
|
||||||
|
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Install the KL3-customized build of PyTorch (Only DeepSeek-V3.2-Exp-w8a8)
|
||||||
|
```
|
||||||
|
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://aihc-private-hcd.bj.bcebos.com/v1/vllm-kunlun-ds/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2026-02-03T01%3A59%3A40Z%2F-1%2Fhost%2Ffc4b6f5b83c2fde70d48fdfc23c40c396efc9cb3c36d6f811fdca5f109073321
|
||||||
|
# For a conda environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
|
||||||
|
# For a uv virtual environment:
|
||||||
|
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
|
||||||
|
mv torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl && \
|
||||||
|
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g; s/torch_xray-999.9.9/torch_xray-2.0.3/' setup.sh && bash setup.sh
|
||||||
|
```
|
||||||
|
## Choose to download customized ops
|
||||||
|
|
||||||
|
### Install custom ops
|
||||||
|
```
|
||||||
|
uv pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
|
||||||
|
```
|
||||||
|
### Install custom ops (Only MIMO V2)
|
||||||
|
```
|
||||||
|
uv pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
|
||||||
|
```
|
||||||
|
### Install custom ops (Only DeepSeek-V3.2-Exp-w8a8)
|
||||||
|
```
|
||||||
|
uv pip install "https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1215/xtorch_ops-0.1.2263%2Bc030eebd-cp310-cp310-linux_x86_64.whl"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Install the KLX3 custom Triton build
|
||||||
|
```
|
||||||
|
uv pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
|
||||||
|
```
|
||||||
|
## Install the AIAK custom ops library
|
||||||
|
```
|
||||||
|
uv pip install "https://vllm-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20260130_152557/xspeedgate_ops-0.0.0%2Be5cdcbe-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKhvtgrTA8US5LIc8Vbl0mP%2F2026-01-30T10%3A33%3A32Z%2F2592000%2Fhost%2F3c13d67cc61d0df7538c198f5c32422f3b034068a40eef43cb51b079cc6f0555" --force-reinstall
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
### Set up the environment
|
### Set up the environment
|
||||||
@@ -81,7 +145,6 @@ chmod +x /workspace/vLLM-Kunlun/setup_env.sh && source /workspace/vLLM-Kunlun/se
|
|||||||
:selected:
|
:selected:
|
||||||
:sync: pip
|
:sync: pip
|
||||||
```{code-block} bash
|
```{code-block} bash
|
||||||
:substitutions:
|
|
||||||
python -m vllm.entrypoints.openai.api_server \
|
python -m vllm.entrypoints.openai.api_server \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
--port 8356 \
|
--port 8356 \
|
||||||
@@ -112,41 +175,3 @@ python -m vllm.entrypoints.openai.api_server \
|
|||||||
```
|
```
|
||||||
::::
|
::::
|
||||||
:::::
|
:::::
|
||||||
|
|
||||||
|
|
||||||
### xpytorch and ops install
|
|
||||||
We also provide xpytorch and ops link for custom installation.
|
|
||||||
|
|
||||||
### Replace eval_frame.py
|
|
||||||
Copy the eval_frame.py patch:
|
|
||||||
```
|
|
||||||
cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
|
|
||||||
```
|
|
||||||
## Install the KL3-customized build of PyTorch
|
|
||||||
```
|
|
||||||
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-02T05%3A01%3A27Z%2F-1%2Fhost%2Ff3cf499234f82303891aed2bcb0628918e379a21e841a3fac6bd94afef491ff7
|
|
||||||
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
|
|
||||||
```
|
|
||||||
## Install the KL3-customized build of PyTorch(Only MIMO V2)
|
|
||||||
```
|
|
||||||
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run
|
|
||||||
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
|
|
||||||
```
|
|
||||||
|
|
||||||
## Install custom ops
|
|
||||||
```
|
|
||||||
pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
|
|
||||||
```
|
|
||||||
## Install custom ops(Only MIMO V2)
|
|
||||||
```
|
|
||||||
pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Install the KLX3 custom Triton build
|
|
||||||
```
|
|
||||||
pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
|
|
||||||
```
|
|
||||||
## Install the AIAK custom ops library
|
|
||||||
```
|
|
||||||
pip install "https://cce-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20251219_152418/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
|
|
||||||
```
|
|
||||||
|
|||||||
140
docs/source/tutorials/multi_xpu_DeepSeek-V3.2-Exp-w8a8.md
Normal file
140
docs/source/tutorials/multi_xpu_DeepSeek-V3.2-Exp-w8a8.md
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
# Multi XPU (DeepSeek-V3.2-Exp-w8a8)
|
||||||
|
|
||||||
|
## Run vllm-kunlun on Multi XPU
|
||||||
|
|
||||||
|
Set up the environment using a container:
|
||||||
|
|
||||||
|
Please follow the [installation.md](../installation.md) document to set up the environment first.
|
||||||
|
|
||||||
|
Create a container
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
# rundocker.sh
|
||||||
|
XPU_NUM=8
|
||||||
|
DOCKER_DEVICE_CONFIG=""
|
||||||
|
if [ $XPU_NUM -gt 0 ]; then
|
||||||
|
for idx in $(seq 0 $((XPU_NUM-1))); do
|
||||||
|
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
|
||||||
|
done
|
||||||
|
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
|
||||||
|
fi
|
||||||
|
|
||||||
|
export build_image="xxx"
|
||||||
|
|
||||||
|
docker run -itd ${DOCKER_DEVICE_CONFIG} \
|
||||||
|
--net=host \
|
||||||
|
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
|
||||||
|
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
|
||||||
|
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
|
||||||
|
--name "$1" \
|
||||||
|
-w /workspace \
|
||||||
|
"$build_image" /bin/bash
|
||||||
|
```
|
||||||
|
|
||||||
|
### Prepare the Weights
|
||||||
|
|
||||||
|
- Pull DeepSeek-V3.2-Exp-w8a8-int8 weights
|
||||||
|
```
|
||||||
|
wget -O DeepSeek-V3.2-Exp-w8a8-int8.tar.gz https://aihc-private-hcd.bj.bcebos.com/v1/LLM/DeepSeek/DeepSeek-V3.2-Exp-w8a8-int8.tar.gz?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2025-12-24T06%3A07%3A10Z%2F-1%2Fhost%2Fa324bf469176934a05f75d3acabc3c1fb891be150f43fb1976e65b7ec68733db
|
||||||
|
```
|
||||||
|
- Ensure that the field "quantization_config" is included. If it is missing, deployment will result in an OOM (Out of Memory) error.
|
||||||
|
|
||||||
|
vim model/DeepSeek-V3.2-Exp-w8a8-int8/config.json
|
||||||
|
```config.json
|
||||||
|
"quantization_config": {
|
||||||
|
"config_groups": {
|
||||||
|
"group_0": {
|
||||||
|
"format": "int-quantized",
|
||||||
|
"input_activations": {
|
||||||
|
"actorder": null,
|
||||||
|
"block_structure": null,
|
||||||
|
"dynamic": true,
|
||||||
|
"group_size": null,
|
||||||
|
"num_bits": 8,
|
||||||
|
"observer": null,
|
||||||
|
"observer_kwargs": {},
|
||||||
|
"strategy": "token",
|
||||||
|
"symmetric": true,
|
||||||
|
"type": "int"
|
||||||
|
},
|
||||||
|
"output_activations": null,
|
||||||
|
"targets": [
|
||||||
|
"Linear"
|
||||||
|
],
|
||||||
|
"weights": {
|
||||||
|
"actorder": null,
|
||||||
|
"block_structure": null,
|
||||||
|
"dynamic": false,
|
||||||
|
"group_size": null,
|
||||||
|
"num_bits": 8,
|
||||||
|
"observer": "minmax",
|
||||||
|
"observer_kwargs": {},
|
||||||
|
"strategy": "channel",
|
||||||
|
"symmetric": true,
|
||||||
|
"type": "int"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"format": "int-quantized",
|
||||||
|
"global_compression_ratio": null,
|
||||||
|
"ignore": [
|
||||||
|
"lm_head"
|
||||||
|
],
|
||||||
|
"kv_cache_scheme": null,
|
||||||
|
"quant_method": "compressed-tensors",
|
||||||
|
"quantization_status": "compressed",
|
||||||
|
"sparsity_config": {},
|
||||||
|
"transform_config": {},
|
||||||
|
"version": "0.12.2"
|
||||||
|
},
|
||||||
|
```
|
||||||
|
|
||||||
|
### Online Serving on Multi XPU
|
||||||
|
|
||||||
|
Start the vLLM server on multi XPU:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
unset XPU_DUMMY_EVENT && \
|
||||||
|
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
|
||||||
|
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
|
||||||
|
export XMLIR_CUDNN_ENABLED=1 && \
|
||||||
|
export XPU_USE_DEFAULT_CTX=1 && \
|
||||||
|
export XMLIR_FORCE_USE_XPU_GRAPH=1 && \
|
||||||
|
export XMLIR_ENABLE_FAST_FC=1 && \
|
||||||
|
export XPU_USE_FAST_SWIGLU=1 && \
|
||||||
|
export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
|
||||||
|
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
|
||||||
|
export XPU_USE_MOE_SORTED_THRES=1 && \
|
||||||
|
export USE_ORI_ROPE=1 && \
|
||||||
|
export VLLM_USE_V1=1
|
||||||
|
|
||||||
|
python -m vllm.entrypoints.openai.api_server \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8806 \
|
||||||
|
--model /data/DeepSeek-V3.2-Exp-w8a8-int8 \
|
||||||
|
--gpu-memory-utilization 0.95 \
|
||||||
|
--trust-remote-code \
|
||||||
|
--max-model-len 32768 \
|
||||||
|
--tensor-parallel-size 8 \
|
||||||
|
--dtype float16 \
|
||||||
|
--max_num_seqs 32 \
|
||||||
|
--max_num_batched_tokens 8192 \
|
||||||
|
--block-size 64 \
|
||||||
|
--no-enable-chunked-prefill \
|
||||||
|
--distributed-executor-backend mp \
|
||||||
|
--disable-log-requests \
|
||||||
|
--no-enable-prefix-caching --kv-cache-dtype bfloat16 \
|
||||||
|
--compilation-config '{"splitting_ops":["vllm.unified_attention",
|
||||||
|
"vllm.unified_attention_with_output",
|
||||||
|
"vllm.unified_attention_with_output_kunlun",
|
||||||
|
"vllm.mamba_mixer2",
|
||||||
|
"vllm.mamba_mixer",
|
||||||
|
"vllm.short_conv",
|
||||||
|
"vllm.linear_attention",
|
||||||
|
"vllm.plamo2_mamba_mixer",
|
||||||
|
"vllm.gdn_attention",
|
||||||
|
"vllm.sparse_attn_indexer",
|
||||||
|
"vllm.sparse_attn_indexer_vllm_kunlun"]}'
|
||||||
|
```
|
||||||
@@ -34,7 +34,9 @@ static inline std::string get_shm_name() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static constexpr uint32_t heartbeat_us = 1000; // microseconds
|
static constexpr uint32_t heartbeat_us = 1000; // microseconds
|
||||||
static constexpr uint32_t heartbeat_timeout_us = 20 * heartbeat_us;
|
static constexpr uint32_t heartbeat_check_everyN = 50;
|
||||||
|
static constexpr uint32_t heartbeat_timeout_us =
|
||||||
|
heartbeat_check_everyN * heartbeat_us;
|
||||||
|
|
||||||
struct alignas(64) WorkerHeartBeat {
|
struct alignas(64) WorkerHeartBeat {
|
||||||
std::atomic<uint64_t> timestamp;
|
std::atomic<uint64_t> timestamp;
|
||||||
|
|||||||
@@ -51,17 +51,16 @@ void ShmManager::set_xpu_info(int device_id, uint32_t xpu_pci_addr,
|
|||||||
void ShmManager::run_busy_loop() {
|
void ShmManager::run_busy_loop() {
|
||||||
spdlog::info("ShmManager busy loop started");
|
spdlog::info("ShmManager busy loop started");
|
||||||
|
|
||||||
int heart_beat_check_everyN = 20;
|
|
||||||
int loop_cnt = 0;
|
int loop_cnt = 0;
|
||||||
|
|
||||||
while (!stop_loop_flag.load(std::memory_order_acquire)) {
|
while (!stop_loop_flag.load(std::memory_order_acquire)) {
|
||||||
process_requests();
|
process_requests();
|
||||||
|
|
||||||
if (loop_cnt % heart_beat_check_everyN == 0) {
|
if (loop_cnt % heartbeat_check_everyN== 0) {
|
||||||
check_heart_beats();
|
check_heart_beats();
|
||||||
}
|
}
|
||||||
|
loop_cnt = (loop_cnt + 1) % heartbeat_check_everyN;
|
||||||
|
|
||||||
loop_cnt = (loop_cnt + 1) % heart_beat_check_everyN;
|
|
||||||
usleep(heartbeat_us);
|
usleep(heartbeat_us);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,86 +1,94 @@
|
|||||||
"""kunlun_ops for lora"""
|
"""kunlun_ops for lora"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import xspeedgate_ops
|
|
||||||
import time
|
|
||||||
from torch._C import dtype
|
|
||||||
import os
|
|
||||||
from torch._dynamo import disable
|
|
||||||
|
|
||||||
|
|
||||||
def sgmv_shrink(
|
def sgmv_shrink(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
lora_a_weights: torch.Tensor,
|
lora_a_weights: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
expert_m: torch.Tensor,
|
expert_m: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
lora_indices_tensor: torch.Tensor,
|
||||||
batches: int,
|
batches: int,
|
||||||
max_seq_length: int,
|
max_seq_length: int,
|
||||||
token_nums: int,
|
token_nums: int,
|
||||||
scaling: float,
|
scaling: float,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
sgmv_shrink
|
sgmv_shrink
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
|
||||||
|
inputs,
|
||||||
return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling)
|
lora_a_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sgmv_expand(inputs: torch.Tensor,
|
def sgmv_expand(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
batches: int,
|
lora_indices_tensor: torch.Tensor,
|
||||||
max_seq_length: int,
|
batches: int,
|
||||||
token_nums: int,
|
max_seq_length: int,
|
||||||
add_inputs: bool = False):
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
sgmv_expand
|
sgmv_expand
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0)
|
inputs,
|
||||||
|
lora_b_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
def sgmv_expand_slice(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
normed_scale: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
b_seq_start_loc: torch.Tensor,
|
normed_scale: torch.Tensor,
|
||||||
seq_len_tensor: torch.Tensor,
|
b_seq_start_loc: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
seq_len_tensor: torch.Tensor,
|
||||||
batches: int,
|
lora_indices_tensor: torch.Tensor,
|
||||||
max_seq_length: int,
|
batches: int,
|
||||||
token_nums: int,
|
max_seq_length: int,
|
||||||
slice_offset: int,
|
token_nums: int,
|
||||||
slice_size: int,
|
slice_offset: int,
|
||||||
add_inputs: bool = False):
|
slice_size: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
sgmv_expand_slice
|
sgmv_expand_slice
|
||||||
"""
|
"""
|
||||||
|
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
|
||||||
|
inputs,
|
||||||
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset)
|
lora_b_weights,
|
||||||
|
seq_len_tensor.to(torch.int32),
|
||||||
|
lora_indices_tensor.to(torch.int32),
|
||||||
|
output_tensor,
|
||||||
|
slice_offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_shrink(
|
def bgmv_shrink(
|
||||||
@@ -92,27 +100,33 @@ def bgmv_shrink(
|
|||||||
moe_index: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
expert_m: torch.Tensor,
|
expert_m: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor, # [m]
|
lora_indices_tensor: torch.Tensor, # [m]
|
||||||
scaling: float = 1.0
|
scaling: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
bgmv_shrink
|
bgmv_shrink
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling)
|
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
|
||||||
|
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand(inputs: torch.Tensor,
|
def bgmv_expand(
|
||||||
lora_b_weights: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
output_tensor: torch.Tensor,
|
lora_b_weights: torch.Tensor,
|
||||||
block_statistic: torch.Tensor,
|
output_tensor: torch.Tensor,
|
||||||
sorted_tokens_num_lod: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
moe_index: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
lora_indices_tensor: torch.Tensor,
|
moe_index: torch.Tensor,
|
||||||
add_inputs: bool = True):
|
lora_indices_tensor: torch.Tensor,
|
||||||
""""
|
add_inputs: bool = True,
|
||||||
bgmv_expand
|
):
|
||||||
|
""" "
|
||||||
|
bgmv_expand
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0)
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
# @my_wrapper
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def bgmv_expand_slice(
|
def bgmv_expand_slice(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
@@ -125,9 +139,11 @@ def bgmv_expand_slice(
|
|||||||
lora_indices_tensor: torch.Tensor,
|
lora_indices_tensor: torch.Tensor,
|
||||||
slice_offset: int,
|
slice_offset: int,
|
||||||
slice_size: int,
|
slice_size: int,
|
||||||
add_inputs: bool = True
|
add_inputs: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
bgmv_expand_slice
|
bgmv_expand_slice
|
||||||
"""
|
"""
|
||||||
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset)
|
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
|
||||||
|
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
|
||||||
|
)
|
||||||
|
|||||||
@@ -22,16 +22,11 @@ Punica: Multi-Tenant LoRA Serving.
|
|||||||
https://arxiv.org/abs/2310.18547
|
https://arxiv.org/abs/2310.18547
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional, Union, final
|
|
||||||
|
|
||||||
import torch
|
|
||||||
# Disable torchdynamo for all functions in this file
|
|
||||||
torch._dynamo.config.disable = True
|
|
||||||
|
|
||||||
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import Callable, Optional, Tuple, Union
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||||
|
|
||||||
from vllm_kunlun.lora.ops.kunlun_ops import (
|
from vllm_kunlun.lora.ops.kunlun_ops import (
|
||||||
bgmv_expand,
|
bgmv_expand,
|
||||||
@@ -42,7 +37,7 @@ from vllm_kunlun.lora.ops.kunlun_ops import (
|
|||||||
sgmv_shrink,
|
sgmv_shrink,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
# Disable torchdynamo for all functions in this file
|
||||||
|
|
||||||
|
|
||||||
# The platforms that are compatible with the PyTorch-native implementation can
|
# The platforms that are compatible with the PyTorch-native implementation can
|
||||||
@@ -545,4 +540,4 @@ class PunicaWrapperKunlun(PunicaWrapperBase):
|
|||||||
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
|
||||||
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
|
||||||
|
|
||||||
y = y.view_as(y_org)
|
y = y.view_as(y_org)
|
||||||
|
|||||||
@@ -14,39 +14,53 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# This file is a part of the vllm-kunlun project.
|
# This file is a part of the vllm-kunlun project.
|
||||||
#
|
#
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
|
||||||
import xtorch_ops
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING
|
from itertools import accumulate
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
import torch
|
||||||
AttentionMetadata, AttentionLayer, AttentionType)
|
import xtorch_ops
|
||||||
|
from vllm.attention.backends.abstract import (
|
||||||
|
AttentionBackend,
|
||||||
|
AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionType,
|
||||||
|
)
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
AttentionCGSupport,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
split_decodes_and_prefills,
|
||||||
|
)
|
||||||
|
|
||||||
# from vllm.attention.backends.utils import CommonAttentionState
|
# from vllm.attention.backends.utils import CommonAttentionState
|
||||||
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
|
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
|
||||||
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata)
|
from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
|
||||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps
|
|
||||||
|
|
||||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
|
||||||
AttentionCGSupport,
|
|
||||||
split_decodes_and_prefills)
|
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
|
||||||
from itertools import accumulate
|
|
||||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
|
||||||
|
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
class KunlunAttentionBackend(AttentionBackend):
|
class KunlunAttentionBackend(AttentionBackend):
|
||||||
"""KunlunAttentionBackend"""
|
"""KunlunAttentionBackend"""
|
||||||
|
|
||||||
# crucial to cuda graph
|
# crucial to cuda graph
|
||||||
accept_output_buffer = True
|
accept_output_buffer = True
|
||||||
|
|
||||||
@@ -81,12 +95,13 @@ class KunlunAttentionBackend(AttentionBackend):
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
cache_dtype_str: str = "auto"
|
cache_dtype_str: str = "auto",
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
"""get_kv_cache_shape"""
|
"""get_kv_cache_shape"""
|
||||||
# return (2, num_blocks, block_size, num_kv_heads * head_size)
|
# return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
return PagedAttention.get_kv_cache_shape(
|
||||||
num_kv_heads, head_size)
|
num_blocks, block_size, num_kv_heads, head_size
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@@ -104,13 +119,12 @@ class KunlunAttentionBackend(AttentionBackend):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""copy_blocks"""
|
"""copy_blocks"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
"""KunlunMetadata"""
|
"""KunlunMetadata"""
|
||||||
|
|
||||||
|
|
||||||
# |---------- N-1 iteration --------|
|
# |---------- N-1 iteration --------|
|
||||||
# |---------------- N iteration ---------------------|
|
# |---------------- N iteration ---------------------|
|
||||||
# |- tokenA -|......................|-- newTokens ---|
|
# |- tokenA -|......................|-- newTokens ---|
|
||||||
@@ -133,7 +147,7 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# Cuda-graph is currently enabled for decoding only.
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
block_tables: torch.Tensor
|
block_tables: torch.Tensor
|
||||||
|
|
||||||
@@ -203,11 +217,13 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
use_cascade: Optional[bool] = False
|
use_cascade: Optional[bool] = False
|
||||||
|
|
||||||
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
seq_lens_tensor_cpu: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
num_prefill_tokens: int = 0
|
num_prefill_tokens: int = 0
|
||||||
num_decode_tokens: int = 0
|
num_decode_tokens: int = 0
|
||||||
num_prefills: int = 0
|
num_prefills: int = 0
|
||||||
num_decodes: int = 0
|
num_decodes: int = 0
|
||||||
|
is_speculative: Optional[bool] = False
|
||||||
|
max_model_len: int = 0
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""__post_init__"""
|
"""__post_init__"""
|
||||||
@@ -218,16 +234,20 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
@property
|
@property
|
||||||
def is_all_encoder_attn_metadata_set(self):
|
def is_all_encoder_attn_metadata_set(self):
|
||||||
"""is_all_encoder_attn_metadata_set"""
|
"""is_all_encoder_attn_metadata_set"""
|
||||||
return ((self.encoder_seq_lens is not None)
|
return (
|
||||||
and (self.encoder_seq_lens_tensor is not None)
|
(self.encoder_seq_lens is not None)
|
||||||
and (self.max_encoder_seq_len is not None))
|
and (self.encoder_seq_lens_tensor is not None)
|
||||||
|
and (self.max_encoder_seq_len is not None)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_all_cross_attn_metadata_set(self):
|
def is_all_cross_attn_metadata_set(self):
|
||||||
"""is_all_cross_attn_metadata_set"""
|
"""is_all_cross_attn_metadata_set"""
|
||||||
return (self.is_all_encoder_attn_metadata_set
|
return (
|
||||||
and (self.cross_slot_mapping is not None)
|
self.is_all_encoder_attn_metadata_set
|
||||||
and (self.cross_block_tables is not None))
|
and (self.cross_slot_mapping is not None)
|
||||||
|
and (self.cross_block_tables is not None)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
def prefill_metadata(self) -> Optional["KunlunMetadata"]:
|
||||||
@@ -240,35 +260,60 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# metadata structure
|
# metadata structure
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
assert ((self.seq_lens_tensor is not None)
|
assert (self.seq_lens_tensor is not None) or (
|
||||||
or (self.encoder_seq_lens_tensor is not None))
|
self.encoder_seq_lens_tensor is not None
|
||||||
|
)
|
||||||
|
|
||||||
# Compute some attn_metadata fields which default to None
|
# Compute some attn_metadata fields which default to None
|
||||||
query_start_loc = (None if self.query_start_loc is None else
|
query_start_loc = (
|
||||||
self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)])
|
None
|
||||||
|
if self.query_start_loc is None
|
||||||
|
else self.query_start_loc[-(self.num_prefills + 1) :]
|
||||||
|
- self.query_start_loc[-(self.num_prefills + 1)]
|
||||||
|
)
|
||||||
# flash attention needs both lod information on host and device
|
# flash attention needs both lod information on host and device
|
||||||
query_start_loc_host = (None if self.query_start_loc_host is None else
|
query_start_loc_host = (
|
||||||
self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)])
|
None
|
||||||
|
if self.query_start_loc_host is None
|
||||||
|
else self.query_start_loc_host[-(self.num_prefills + 1) :]
|
||||||
|
- self.query_start_loc_host[-(self.num_prefills + 1)]
|
||||||
|
)
|
||||||
|
|
||||||
# TODO(chengruichang):how to support prefix cache
|
# TODO(chengruichang):how to support prefix cache
|
||||||
kv_prefix_start_loc_host = None
|
kv_prefix_start_loc_host = None
|
||||||
kv_prefix_start_loc = None
|
kv_prefix_start_loc = None
|
||||||
slot_mapping = (None if self.slot_mapping is None else
|
slot_mapping = (
|
||||||
self.slot_mapping[-self.num_prefill_tokens:])
|
None
|
||||||
|
if self.slot_mapping is None
|
||||||
|
else self.slot_mapping[-self.num_prefill_tokens :]
|
||||||
|
)
|
||||||
|
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
seq_lens_tensor = (
|
||||||
self.seq_lens_tensor[-self.num_prefills:])
|
None
|
||||||
seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:])
|
if self.seq_lens_tensor is None
|
||||||
|
else self.seq_lens_tensor[-self.num_prefills :]
|
||||||
|
)
|
||||||
|
seq_lens = (
|
||||||
|
None if self.seq_lens is None else self.seq_lens[-self.num_prefills :]
|
||||||
|
)
|
||||||
|
|
||||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
context_lens_tensor = (
|
||||||
self.context_lens_tensor[-self.num_prefills:])
|
None
|
||||||
|
if self.context_lens_tensor is None
|
||||||
block_tables = (None if self.block_tables is None else
|
else self.context_lens_tensor[-self.num_prefills :]
|
||||||
self.block_tables[-self.num_prefills:])
|
)
|
||||||
input_positions = (None if self.input_positions is None else
|
|
||||||
self.input_positions[-self.num_prefills:])
|
block_tables = (
|
||||||
|
None
|
||||||
|
if self.block_tables is None
|
||||||
|
else self.block_tables[-self.num_prefills :]
|
||||||
|
)
|
||||||
|
input_positions = (
|
||||||
|
None
|
||||||
|
if self.input_positions is None
|
||||||
|
else self.input_positions[-self.num_prefills :]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if self.kv_lod_cpu is None:
|
if self.kv_lod_cpu is None:
|
||||||
kv_lod_cpu = None
|
kv_lod_cpu = None
|
||||||
kv_lod_xpu = None
|
kv_lod_xpu = None
|
||||||
@@ -280,19 +325,17 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
base_xpu = self.kv_lod_xpu[start]
|
base_xpu = self.kv_lod_xpu[start]
|
||||||
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
|
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
|
||||||
|
|
||||||
|
|
||||||
# Construct & cache prefill-phase attention metadata structure
|
# Construct & cache prefill-phase attention metadata structure
|
||||||
self._cached_prefill_metadata = KunlunMetadata(
|
self._cached_prefill_metadata = KunlunMetadata(
|
||||||
num_actual_tokens=self.num_actual_tokens,
|
num_actual_tokens=self.num_actual_tokens,
|
||||||
multi_modal_placeholder_index_maps=self.
|
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||||
multi_modal_placeholder_index_maps,
|
|
||||||
num_prefills=self.num_prefills,
|
num_prefills=self.num_prefills,
|
||||||
num_prefill_tokens=self.num_prefill_tokens,
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
num_decode_tokens=0,
|
num_decode_tokens=0,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
seq_start_loc = None,
|
seq_start_loc=None,
|
||||||
kv_lod_cpu=kv_lod_cpu,
|
kv_lod_cpu=kv_lod_cpu,
|
||||||
kv_lod_xpu=kv_lod_xpu,
|
kv_lod_xpu=kv_lod_xpu,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
@@ -314,7 +357,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
cross_slot_mapping=self.cross_slot_mapping,
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
cross_block_tables=self.cross_block_tables,
|
cross_block_tables=self.cross_block_tables,
|
||||||
enable_kv_scales_calculation=False,
|
enable_kv_scales_calculation=False,
|
||||||
use_cascade=self.use_cascade)
|
use_cascade=self.use_cascade,
|
||||||
|
is_speculative=self.is_speculative,
|
||||||
|
)
|
||||||
return self._cached_prefill_metadata
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -327,40 +372,47 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# Recover cached decode-phase attention
|
# Recover cached decode-phase attention
|
||||||
# metadata structure
|
# metadata structure
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
assert ((self.seq_lens_tensor is not None)
|
assert (self.seq_lens_tensor is not None) or (
|
||||||
or (self.encoder_seq_lens_tensor is not None))
|
self.encoder_seq_lens_tensor is not None
|
||||||
|
)
|
||||||
|
|
||||||
if self.num_prefills != 0:
|
if self.num_prefills != 0:
|
||||||
# Compute some attn_metadata fields which default to None
|
# Compute some attn_metadata fields which default to None
|
||||||
slot_mapping = (None if self.slot_mapping is None else
|
slot_mapping = (
|
||||||
self.slot_mapping[:-self.num_prefill_tokens])
|
None
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
if self.slot_mapping is None
|
||||||
self.seq_lens_tensor[:-self.num_prefills])
|
else self.slot_mapping[: -self.num_prefill_tokens]
|
||||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
)
|
||||||
self.seq_lens_tensor_cpu[:-self.num_prefills])
|
seq_lens_tensor = (
|
||||||
|
None
|
||||||
block_tables = (None if self.block_tables is None else
|
if self.seq_lens_tensor is None
|
||||||
self.block_tables[:-self.num_prefills])
|
else self.seq_lens_tensor[: -self.num_prefills]
|
||||||
|
)
|
||||||
|
seq_lens_tensor_cpu = (
|
||||||
|
None
|
||||||
|
if self.seq_lens_tensor_cpu is None
|
||||||
|
else self.seq_lens_tensor_cpu[: -self.num_prefills]
|
||||||
|
)
|
||||||
|
block_tables = (
|
||||||
|
None
|
||||||
|
if self.block_tables is None
|
||||||
|
else self.block_tables[: -self.num_prefills]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Compute some attn_metadata fields which default to None
|
# Compute some attn_metadata fields which default to None
|
||||||
slot_mapping = (None if self.slot_mapping is None else
|
slot_mapping = None if self.slot_mapping is None else self.slot_mapping
|
||||||
self.slot_mapping)
|
seq_lens_tensor = (
|
||||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
None if self.seq_lens_tensor is None else self.seq_lens_tensor
|
||||||
self.seq_lens_tensor)
|
)
|
||||||
|
seq_lens_tensor_cpu = (
|
||||||
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else
|
None if self.seq_lens_tensor_cpu is None else self.seq_lens_tensor_cpu
|
||||||
self.seq_lens_tensor_cpu)
|
)
|
||||||
|
block_tables = None if self.block_tables is None else self.block_tables
|
||||||
|
|
||||||
block_tables = (None if self.block_tables is None else
|
|
||||||
self.block_tables)
|
|
||||||
|
|
||||||
|
|
||||||
# Construct & cache decode-phase attention metadata structure
|
# Construct & cache decode-phase attention metadata structure
|
||||||
self._cached_decode_metadata = KunlunMetadata(
|
self._cached_decode_metadata = KunlunMetadata(
|
||||||
num_actual_tokens=self.num_actual_tokens,
|
num_actual_tokens=self.num_actual_tokens,
|
||||||
multi_modal_placeholder_index_maps=self.
|
multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
|
||||||
multi_modal_placeholder_index_maps,
|
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=self.num_decode_tokens,
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
@@ -378,19 +430,29 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
cross_slot_mapping=self.cross_slot_mapping,
|
cross_slot_mapping=self.cross_slot_mapping,
|
||||||
cross_block_tables=self.cross_block_tables,
|
cross_block_tables=self.cross_block_tables,
|
||||||
enable_kv_scales_calculation=False,
|
enable_kv_scales_calculation=False,
|
||||||
use_cascade=self.use_cascade)
|
use_cascade=self.use_cascade,
|
||||||
|
is_speculative=self.is_speculative,
|
||||||
|
max_model_len=self.max_model_len,
|
||||||
|
)
|
||||||
return self._cached_decode_metadata
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
|
M = TypeVar("M")
|
||||||
|
|
||||||
|
|
||||||
class KunlunAttentionMetadataBuilder:
|
class KunlunAttentionMetadataBuilder:
|
||||||
"""KunlunAttentionMetadataBuilder"""
|
"""KunlunAttentionMetadataBuilder"""
|
||||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
|
||||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
reorder_batch_threshold: ClassVar[Optional[int]] = 1
|
reorder_batch_threshold: ClassVar[Optional[int]] = 1
|
||||||
|
|
||||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
def __init__(
|
||||||
vllm_config: VllmConfig, device: torch.device):
|
self,
|
||||||
|
kv_cache_spec: AttentionSpec,
|
||||||
|
layer_names: list[str],
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
"""__init__"""
|
"""__init__"""
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
@@ -398,17 +460,45 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
self.compilation_config = vllm_config.compilation_config
|
self.compilation_config = vllm_config.compilation_config
|
||||||
|
|
||||||
self.num_heads_q = self.model_config.get_num_attention_heads(
|
self.num_heads_q = self.model_config.get_num_attention_heads(
|
||||||
self.parallel_config)
|
self.parallel_config
|
||||||
self.num_heads_kv = self.model_config.get_num_kv_heads(
|
)
|
||||||
self.parallel_config)
|
self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||||
self.headdim = self.model_config.get_head_size()
|
self.headdim = self.model_config.get_head_size()
|
||||||
|
|
||||||
self.block_size = kv_cache_spec.block_size
|
self.block_size = kv_cache_spec.block_size
|
||||||
self.kv_cache_spec = kv_cache_spec
|
self.kv_cache_spec = kv_cache_spec
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def reorder_batch(self, input_batch: "InputBatch",
|
def _init_reorder_batch_threshold(
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
self,
|
||||||
|
reorder_batch_threshold: int | None = 1,
|
||||||
|
supports_spec_as_decode: bool = False,
|
||||||
|
supports_dcp_with_varlen: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.reorder_batch_threshold = reorder_batch_threshold
|
||||||
|
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||||
|
# If the backend supports spec-as-decode kernels, then we can set
|
||||||
|
# the reorder_batch_threshold based on the number of speculative
|
||||||
|
# tokens from the config.
|
||||||
|
speculative_config = self.vllm_config.speculative_config
|
||||||
|
if (
|
||||||
|
speculative_config is not None
|
||||||
|
and speculative_config.num_speculative_tokens is not None
|
||||||
|
):
|
||||||
|
self.reorder_batch_threshold = max(
|
||||||
|
self.reorder_batch_threshold,
|
||||||
|
1 + speculative_config.num_speculative_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||||
|
and not supports_dcp_with_varlen
|
||||||
|
):
|
||||||
|
self.reorder_batch_threshold = 1
|
||||||
|
|
||||||
|
def reorder_batch(
|
||||||
|
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
|
||||||
|
) -> bool:
|
||||||
"""reorder_batch"""
|
"""reorder_batch"""
|
||||||
decodes = []
|
decodes = []
|
||||||
prefills = []
|
prefills = []
|
||||||
@@ -432,8 +522,9 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
|
|
||||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||||
if decodes[num_decodes - i] >= num_decodes:
|
if decodes[num_decodes - i] >= num_decodes:
|
||||||
input_batch.swap_states(prefills[first_prefill],
|
input_batch.swap_states(
|
||||||
decodes[num_decodes - i])
|
prefills[first_prefill], decodes[num_decodes - i]
|
||||||
|
)
|
||||||
first_prefill += 1
|
first_prefill += 1
|
||||||
modified_batch = True
|
modified_batch = True
|
||||||
else:
|
else:
|
||||||
@@ -443,7 +534,7 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
self._num_decode_tokens = num_decode_tokens
|
self._num_decode_tokens = num_decode_tokens
|
||||||
self._num_prefill_tokens = num_prefill_tokens
|
self._num_prefill_tokens = num_prefill_tokens
|
||||||
return modified_batch
|
return modified_batch
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata
|
self, common_attn_metadata: CommonAttentionMetadata
|
||||||
) -> KunlunMetadata:
|
) -> KunlunMetadata:
|
||||||
@@ -454,8 +545,30 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
attn_metadata.seq_lens_tensor.fill_(1)
|
attn_metadata.seq_lens_tensor.fill_(1)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def build(self, common_prefix_len: int,
|
def build_for_drafting(
|
||||||
common_attn_metadata: CommonAttentionMetadata):
|
self,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
draft_index: int,
|
||||||
|
) -> M:
|
||||||
|
"""
|
||||||
|
Build attention metadata for draft model. Uses build by default.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
common_attn_metadata: The common attention metadata.
|
||||||
|
draft_index: The index of the current draft operation.
|
||||||
|
When speculating a chain of tokens, this index refers to the
|
||||||
|
draft attempt for the i-th token.
|
||||||
|
For tree-based attention, this index instead refers to the
|
||||||
|
draft attempt for the i-th level in the tree of tokens.
|
||||||
|
"""
|
||||||
|
return self.build(
|
||||||
|
common_prefix_len=0,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def build(
|
||||||
|
self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
|
||||||
|
):
|
||||||
"""build"""
|
"""build"""
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
@@ -464,30 +577,38 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
slot_mapping = common_attn_metadata.slot_mapping
|
slot_mapping = common_attn_metadata.slot_mapping
|
||||||
|
|
||||||
|
|
||||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]
|
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to(
|
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
|
||||||
|
|
||||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||||
|
|
||||||
|
|
||||||
|
seq_start_loc_tensor = torch.empty(
|
||||||
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device)
|
len(seq_start_loc), dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
|
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
|
||||||
|
|
||||||
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
|
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
|
||||||
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
|
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
|
||||||
kv_lod_xpu = kv_lod_cpu.to(self.device)
|
kv_lod_xpu = kv_lod_cpu.to(self.device)
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
|
||||||
split_decodes_and_prefills(common_attn_metadata)
|
|
||||||
|
|
||||||
num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1])
|
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||||
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
|
split_decodes_and_prefills(
|
||||||
|
common_attn_metadata,
|
||||||
|
decode_threshold=self.reorder_batch_threshold or 1,
|
||||||
|
require_uniform=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
num_scheduled_tokens = np.diff(
|
||||||
|
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||||
|
)
|
||||||
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
|
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
|
||||||
|
|
||||||
if num_decode_tokens == 0:
|
if num_decode_tokens == 0:
|
||||||
@@ -495,18 +616,19 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
else:
|
else:
|
||||||
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
|
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
|
||||||
|
|
||||||
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs]
|
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs]
|
||||||
|
|
||||||
if num_prefill_tokens == 0:
|
if num_prefill_tokens == 0:
|
||||||
max_prefill_seq_len = 0
|
max_prefill_seq_len = 0
|
||||||
else:
|
else:
|
||||||
max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)
|
max_prefill_seq_len = np.max(tmp_prefill_scheduled_tokens)
|
||||||
|
|
||||||
use_cascade = False
|
use_cascade = False
|
||||||
|
|
||||||
attn_metadata = KunlunMetadata(
|
attn_metadata = KunlunMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
num_prefills=num_prefills,
|
num_prefills=num_prefills,
|
||||||
|
num_decodes=num_decodes,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
multi_modal_placeholder_index_maps=None,
|
multi_modal_placeholder_index_maps=None,
|
||||||
enable_kv_scales_calculation=True,
|
enable_kv_scales_calculation=True,
|
||||||
@@ -525,11 +647,14 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
block_tables=block_table_tensor,
|
block_tables=block_table_tensor,
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
|
is_speculative=self.reorder_batch_threshold > 1,
|
||||||
|
max_model_len=self.vllm_config.model_config.max_model_len,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def can_run_in_cudagraph(
|
def can_run_in_cudagraph(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
self, common_attn_metadata: CommonAttentionMetadata
|
||||||
|
) -> bool:
|
||||||
"""can_run_in_cudagraph"""
|
"""can_run_in_cudagraph"""
|
||||||
# Full CUDA Graph always supported (FA2 support checked separately)
|
# Full CUDA Graph always supported (FA2 support checked separately)
|
||||||
return True
|
return True
|
||||||
@@ -538,6 +663,7 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
"""use_cascade_attention"""
|
"""use_cascade_attention"""
|
||||||
return use_cascade_attention(*args, **kwargs)
|
return use_cascade_attention(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||||
"""KunlunAttentionImpl"""
|
"""KunlunAttentionImpl"""
|
||||||
|
|
||||||
@@ -555,13 +681,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
kv_sharing_target_layer_name: Optional[str] = None,
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
use_irope: bool = False,
|
use_irope: bool = False,
|
||||||
sinks:Optional[torch.Tensor]= None,
|
sinks: Optional[torch.Tensor] = None,
|
||||||
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None,
|
multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""__init__"""
|
"""__init__"""
|
||||||
if blocksparse_params is not None:
|
if blocksparse_params is not None:
|
||||||
raise ValueError(
|
raise ValueError("kunlunAttention does not support block-sparse attention.")
|
||||||
"kunlunAttention does not support block-sparse attention.")
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
@@ -582,15 +707,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
if head_size not in suppored_head_sizes:
|
if head_size not in suppored_head_sizes:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Head size {head_size} is not supported by PagedAttention. "
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
f"Supported head sizes are: {suppored_head_sizes}."
|
||||||
|
)
|
||||||
|
|
||||||
self.sinks = sinks
|
self.sinks = sinks
|
||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
assert sinks.shape[0] == num_heads, (
|
assert sinks.shape[0] == num_heads, (
|
||||||
"Sinks must have the same number of heads as the number of "
|
"Sinks must have the same number of heads as the number of "
|
||||||
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
f"heads in the layer. Sinks shape: {sinks.shape}, "
|
||||||
f"num_heads: {num_heads}.")
|
f"num_heads: {num_heads}."
|
||||||
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
|
)
|
||||||
|
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -605,7 +732,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
output_scale: Optional[torch.Tensor] = None,
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
output_block_scale: Optional[torch.Tensor] = None
|
output_block_scale: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""forward"""
|
"""forward"""
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
@@ -624,7 +751,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
# Self-attention vs. cross-attention will impact
|
# Self-attention vs. cross-attention will impact
|
||||||
# which KV cache memory-mapping & which
|
# which KV cache memory-mapping & which
|
||||||
# seqlen datastructures we utilize
|
# seqlen datastructures we utilize
|
||||||
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0):
|
if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
|
||||||
# KV-cache during decoder-self- or
|
# KV-cache during decoder-self- or
|
||||||
# encoder-decoder-cross-attention, but not
|
# encoder-decoder-cross-attention, but not
|
||||||
# during encoder attention.
|
# during encoder attention.
|
||||||
@@ -633,7 +760,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
# we still need to break out key_cache and value_cache
|
# we still need to break out key_cache and value_cache
|
||||||
# i.e. for later use by paged attention
|
# i.e. for later use by paged attention
|
||||||
key_cache, value_cache = PagedAttention.split_kv_cache(
|
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||||
kv_cache, self.num_kv_heads, self.head_size)
|
kv_cache, self.num_kv_heads, self.head_size
|
||||||
|
)
|
||||||
|
|
||||||
if (key is not None) and (value is not None):
|
if (key is not None) and (value is not None):
|
||||||
updated_slot_mapping = attn_metadata.slot_mapping
|
updated_slot_mapping = attn_metadata.slot_mapping
|
||||||
@@ -644,11 +772,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
value = value.contiguous()
|
value = value.contiguous()
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
xtorch_ops.reshape_and_cache(
|
xtorch_ops.reshape_and_cache(
|
||||||
key,
|
key[: attn_metadata.num_actual_tokens],
|
||||||
value,
|
value[: attn_metadata.num_actual_tokens],
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
updated_slot_mapping)
|
updated_slot_mapping,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||||
@@ -657,7 +786,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
value,
|
value,
|
||||||
cast_key_cache,
|
cast_key_cache,
|
||||||
cast_value_cache,
|
cast_value_cache,
|
||||||
updated_slot_mapping)
|
updated_slot_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
assert attn_type == AttentionType.DECODER
|
assert attn_type == AttentionType.DECODER
|
||||||
# Decoder self-attention supports chunked prefill.
|
# Decoder self-attention supports chunked prefill.
|
||||||
@@ -668,88 +798,98 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
# Prompt run.
|
# Prompt run.
|
||||||
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens]
|
prefill_query = query[num_decode_tokens : attn_metadata.num_actual_tokens]
|
||||||
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens]
|
prefill_key = key[num_decode_tokens : attn_metadata.num_actual_tokens]
|
||||||
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens]
|
prefill_value = value[num_decode_tokens : attn_metadata.num_actual_tokens]
|
||||||
|
|
||||||
# For hybrid Attention (Qwen3-Next.)
|
# For hybrid Attention (Qwen3-Next.)
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
tmp_block_tables = prefill_meta.block_tables
|
tmp_block_tables = prefill_meta.block_tables
|
||||||
else:
|
else:
|
||||||
# For hybrid Attention (Qwen3-Next)
|
# For hybrid Attention (Qwen3-Next)
|
||||||
tmp_block_tables = prefill_meta.block_tables * 2
|
tmp_block_tables = prefill_meta.block_tables * 2
|
||||||
|
|
||||||
# Prefix cache
|
# Prefix cache
|
||||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||||
xtorch_ops.prefill_attention(
|
xtorch_ops.prefill_attention(
|
||||||
q=prefill_query,
|
q=prefill_query,
|
||||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
is_prefix_cache=True,
|
is_prefix_cache=True,
|
||||||
block_table=tmp_block_tables,
|
block_table=tmp_block_tables,
|
||||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||||
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
|
||||||
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
|
||||||
alibi_slopes=self.alibi_slopes,
|
alibi_slopes=self.alibi_slopes,
|
||||||
softmax_lse=None
|
softmax_lse=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
xtorch_ops.prefill_attention(
|
xtorch_ops.prefill_attention(
|
||||||
q=prefill_query,
|
q=prefill_query,
|
||||||
k=prefill_key,
|
k=prefill_key,
|
||||||
v=prefill_value,
|
v=prefill_value,
|
||||||
out=output[num_decode_tokens:attn_metadata.num_actual_tokens],
|
out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
|
||||||
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
context_qlen_lod_xpu=prefill_meta.query_start_loc,
|
||||||
alibi_slopes=self.alibi_slopes,
|
alibi_slopes=self.alibi_slopes,
|
||||||
softmax_lse=None,
|
softmax_lse=None,
|
||||||
swa_left = self.sliding_window if self.sliding_window is not None else -1,
|
swa_left=(
|
||||||
swa_right = 0 if self.sliding_window is not None else -1,
|
self.sliding_window if self.sliding_window is not None else -1
|
||||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
),
|
||||||
|
swa_right=0 if self.sliding_window is not None else -1,
|
||||||
|
sink=(
|
||||||
|
self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
assert (
|
||||||
assert attn_type != AttentionType.ENCODER_ONLY, (
|
attn_type != AttentionType.ENCODER_ONLY
|
||||||
"Encoder-only models should not have decode metadata.")
|
), "Encoder-only models should not have decode metadata."
|
||||||
decode_query = query[:num_decode_tokens]
|
decode_query = query[:num_decode_tokens]
|
||||||
|
|
||||||
# For hybrid Attention (Qwen3-Next
|
# For hybrid Attention (Qwen3-Next
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
tmp_block_tables = decode_meta.block_tables
|
tmp_block_tables = decode_meta.block_tables
|
||||||
else:
|
else:
|
||||||
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next
|
tmp_block_tables = (
|
||||||
|
decode_meta.block_tables * 2
|
||||||
|
) # only test in Qwen3-Next
|
||||||
|
|
||||||
sig = inspect.signature(xtorch_ops.speculative_attention)
|
sig = inspect.signature(xtorch_ops.speculative_attention)
|
||||||
if "max_window_size" in sig.parameters:
|
if "max_window_size" in sig.parameters:
|
||||||
xtorch_ops.speculative_attention(
|
xtorch_ops.speculative_attention(
|
||||||
out=output[:num_decode_tokens],
|
out=output[:num_decode_tokens],
|
||||||
# Only MLA support q len > 1 right now
|
# Only MLA support q len > 1 right now
|
||||||
q=decode_query.unsqueeze(0),
|
q=decode_query.unsqueeze(0),
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||||
context_lens_xpu=decode_meta.seq_lens_tensor,
|
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||||
batch_num=decode_meta.block_tables.shape[0],
|
batch_num=decode_meta.block_tables.shape[0],
|
||||||
# TODO (@xyDong23): Support MTP(q lens >1)
|
# TODO (@xyDong23): Support MTP(q lens >1)
|
||||||
qlen=1,
|
qlen=1,
|
||||||
# TODO (@xyDong23): Support max_context_len to (262144)
|
# TODO (@xyDong23): Support max_context_len to (262144)
|
||||||
max_context_len=131072,
|
max_context_len=131072,
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
head_dim=self.head_size,
|
head_dim=self.head_size,
|
||||||
scale=0.0,
|
scale=0.0,
|
||||||
kv_head_num=self.num_kv_heads,
|
kv_head_num=self.num_kv_heads,
|
||||||
block_size=key_cache.shape[2],
|
block_size=key_cache.shape[2],
|
||||||
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
||||||
max_window_size=self.sliding_window if self.sliding_window is not None else -1,
|
max_window_size=(
|
||||||
block_tables=tmp_block_tables,
|
self.sliding_window if self.sliding_window is not None else -1
|
||||||
sink = self.sinks.to(torch.float32) if self.sinks is not None else None
|
),
|
||||||
|
block_tables=tmp_block_tables,
|
||||||
|
sink=(
|
||||||
|
self.sinks.to(torch.float32) if self.sinks is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
elif not attn_metadata.is_speculative:
|
||||||
xtorch_ops.paged_attention(
|
xtorch_ops.paged_attention(
|
||||||
x=decode_query,
|
x=decode_query,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
@@ -760,10 +900,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
is_context=False,
|
is_context=False,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
out=output[:num_decode_tokens],
|
out=output[:num_decode_tokens],
|
||||||
vo_head_dim=self.head_size
|
vo_head_dim=self.head_size,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
batch_size = attn_metadata.num_decodes
|
||||||
|
query_seq_len, head_num, head_dim = decode_query.shape
|
||||||
|
assert query_seq_len % batch_size == 0
|
||||||
|
qlen = query_seq_len // batch_size
|
||||||
|
out = output[:num_decode_tokens]
|
||||||
|
assert out.is_contiguous()
|
||||||
|
|
||||||
|
xtorch_ops.speculative_attention(
|
||||||
|
out=out.view(batch_size, qlen, head_num, self.head_size),
|
||||||
|
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
|
||||||
|
context_lens_xpu=decode_meta.seq_lens_tensor,
|
||||||
|
batch_num=batch_size,
|
||||||
|
qlen=qlen,
|
||||||
|
max_context_len=decode_meta.max_model_len,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
head_dim=self.head_size,
|
||||||
|
scale=0.0,
|
||||||
|
kv_head_num=self.num_kv_heads,
|
||||||
|
block_size=key_cache.shape[2],
|
||||||
|
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
|
||||||
|
block_tables=tmp_block_tables,
|
||||||
|
)
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
return output.view(-1, self.num_heads * self.head_size)
|
return output.view(-1, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
|
||||||
def use_cascade_attention(
|
def use_cascade_attention(
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
query_lens: np.ndarray,
|
query_lens: np.ndarray,
|
||||||
@@ -785,7 +953,7 @@ def use_cascade_attention(
|
|||||||
# NOTE(woosuk): This is the common case. We should return False as soon as
|
# NOTE(woosuk): This is the common case. We should return False as soon as
|
||||||
# possible to avoid any unnecessary computation.
|
# possible to avoid any unnecessary computation.
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if common_prefix_len < 256:
|
if common_prefix_len < 256:
|
||||||
return False
|
return False
|
||||||
# Cascade attention is currently not supported with these variants.
|
# Cascade attention is currently not supported with these variants.
|
||||||
@@ -803,8 +971,12 @@ def use_cascade_attention(
|
|||||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||||
# The criteria for using FlashDecoding can be found in the following link:
|
# The criteria for using FlashDecoding can be found in the following link:
|
||||||
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
|
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
|
||||||
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
|
use_flash_decoding = (
|
||||||
and not use_alibi and np.all(query_lens == 1))
|
num_queries_per_kv > 1
|
||||||
|
and not use_sliding_window
|
||||||
|
and not use_alibi
|
||||||
|
and np.all(query_lens == 1)
|
||||||
|
)
|
||||||
if not use_flash_decoding:
|
if not use_flash_decoding:
|
||||||
# Use cascade attention.
|
# Use cascade attention.
|
||||||
return True
|
return True
|
||||||
@@ -826,10 +998,11 @@ def use_cascade_attention(
|
|||||||
cascade_waves = cdiv(cascade_ctas, num_sms)
|
cascade_waves = cdiv(cascade_ctas, num_sms)
|
||||||
cascade_time = cascade_waves * num_prefix_tiles
|
cascade_time = cascade_waves * num_prefix_tiles
|
||||||
|
|
||||||
flash_decoding_ctas = (num_reqs * num_kv_heads *
|
flash_decoding_ctas = (
|
||||||
cdiv(num_queries_per_kv, q_tile_size))
|
num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
|
||||||
|
)
|
||||||
flash_decoding_ctas *= num_prefix_tiles
|
flash_decoding_ctas *= num_prefix_tiles
|
||||||
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
|
||||||
|
|
||||||
# Use cascade attention if it is faster than FlashDecoding.
|
# Use cascade attention if it is faster than FlashDecoding.
|
||||||
return cascade_time < flash_decoding_time
|
return cascade_time < flash_decoding_time
|
||||||
|
|||||||
@@ -1,16 +1,24 @@
|
|||||||
"""vllm_utils_wrapper.py"""
|
"""vllm_utils_wrapper.py"""
|
||||||
|
|
||||||
import vllm.distributed.parallel_state as parallel_state
|
|
||||||
import vllm.utils as _orig
|
|
||||||
from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple
|
|
||||||
from types import SimpleNamespace
|
|
||||||
import torch
|
|
||||||
from torch.library import Library
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import socket
|
||||||
import typing
|
import typing
|
||||||
from torch.library import register_fake
|
from types import SimpleNamespace
|
||||||
import vllm_kunlun._kunlun
|
from typing import Any, Callable, List, Optional, Tuple, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm.distributed.parallel_state as parallel_state
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
import vllm.utils as _orig
|
||||||
|
from torch.library import Library, register_fake
|
||||||
|
|
||||||
|
try:
|
||||||
|
import vllm_kunlun._kunlun # noqa: F401
|
||||||
|
except ImportError as e:
|
||||||
|
try:
|
||||||
|
from . import _kunlun # noqa: F401, F403
|
||||||
|
except ImportError:
|
||||||
|
print(f"Warning: Failed to load vllm_kunlun native extension: {e}")
|
||||||
|
|
||||||
|
|
||||||
def patch_annotations_for_schema(func):
|
def patch_annotations_for_schema(func):
|
||||||
@@ -87,7 +95,7 @@ def direct_register_custom_op(
|
|||||||
import torch.library
|
import torch.library
|
||||||
|
|
||||||
if hasattr(torch.library, "infer_schema"):
|
if hasattr(torch.library, "infer_schema"):
|
||||||
patched_func = patch_annotations_for_schema(op_func)
|
patch_annotations_for_schema(op_func)
|
||||||
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
|
||||||
else:
|
else:
|
||||||
# for pytorch 2.4
|
# for pytorch 2.4
|
||||||
@@ -153,7 +161,7 @@ _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
|
|||||||
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
|
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
|
||||||
_wrapped._get_open_port = _get_open_port
|
_wrapped._get_open_port = _get_open_port
|
||||||
|
|
||||||
import sys
|
import sys # noqa: E402
|
||||||
|
|
||||||
sys.modules["vllm.utils"] = _wrapped
|
sys.modules["vllm.utils"] = _wrapped
|
||||||
|
|
||||||
@@ -204,11 +212,10 @@ parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce
|
|||||||
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
|
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
|
||||||
|
|
||||||
|
|
||||||
from torch.library import custom_op, impl
|
from typing import Optional # noqa: E402
|
||||||
import torch
|
|
||||||
from vllm import _custom_ops as ops
|
import torch # noqa: E402
|
||||||
from typing import Optional, List
|
from torch.library import custom_op, impl # noqa: E402
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
@custom_op("_C::rms_norm", mutates_args=())
|
@custom_op("_C::rms_norm", mutates_args=())
|
||||||
@@ -379,9 +386,9 @@ def silu_and_mul_quant_xpu(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch # noqa: E402
|
||||||
import xtorch_ops
|
import xtorch_ops # noqa: E402
|
||||||
from torch.library import custom_op, impl
|
from torch.library import custom_op, impl # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
@custom_op("_C::add_rmsnorm", mutates_args=())
|
@custom_op("_C::add_rmsnorm", mutates_args=())
|
||||||
@@ -472,7 +479,7 @@ def rmsnorm_cuda(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def _fake_rmsnorm(
|
def _fake_rmsnorm(
|
||||||
@@ -618,7 +625,6 @@ split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
|
|||||||
|
|
||||||
# register fake op impl here
|
# register fake op impl here
|
||||||
# for torch.dynamo
|
# for torch.dynamo
|
||||||
from torch.library import register_fake
|
|
||||||
|
|
||||||
if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
||||||
|
|
||||||
@@ -1396,7 +1402,7 @@ def awq_dequantize_cuda(
|
|||||||
device=qweight.device,
|
device=qweight.device,
|
||||||
)
|
)
|
||||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||||
out = xtorch_ops.awq_dequantize(
|
xtorch_ops.awq_dequantize(
|
||||||
qweight=qweight,
|
qweight=qweight,
|
||||||
scales=scales,
|
scales=scales,
|
||||||
zeros=zeros,
|
zeros=zeros,
|
||||||
@@ -1915,7 +1921,7 @@ def apply_repetition_penalties_(
|
|||||||
|
|
||||||
|
|
||||||
@impl("_C::apply_repetition_penalties_", "CUDA")
|
@impl("_C::apply_repetition_penalties_", "CUDA")
|
||||||
def apply_repetition_penalties_(
|
def apply_repetition_penalties_cuda(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
prompt_mask: torch.Tensor,
|
prompt_mask: torch.Tensor,
|
||||||
output_mask: torch.Tensor,
|
output_mask: torch.Tensor,
|
||||||
@@ -2341,34 +2347,499 @@ dequant_int4.register_fake(_fake_dequant_int4)
|
|||||||
##################################################
|
##################################################
|
||||||
@custom_op("_C::fast_topkv2", mutates_args=())
def fast_topkv2(
    score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
    """Custom-op entry point that forwards to ``xtorch_ops.fast_topkv2``.

    Args:
        score: Score tensor passed to the kernel as ``score``.
        lengths: Tensor passed to the kernel as ``lengths``.
        topk: Number of entries to select; must be 2048.

    Returns:
        The tensor of top-k indices produced by ``xtorch_ops.fast_topkv2``.

    Raises:
        AssertionError: If ``topk`` is anything other than 2048.
    """
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::fast_topkv2", "CUDA")
def fast_topkv2_cuda(
    score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
    """Implementation of ``_C::fast_topkv2`` registered under the "CUDA" key.

    Forwards ``score``, ``lengths`` and ``topk`` to ``xtorch_ops.fast_topkv2``
    and returns the resulting index tensor.

    Raises:
        AssertionError: If ``topk`` is anything other than 2048.
    """
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||||
|
|
||||||
|
|
||||||
def _fake_fast_topkv2(
|
def _fake_fast_topkv2(
|
||||||
score: torch.Tensor,
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
lengths: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
topk: Optional[int] = 2048) -> torch.Tensor:
|
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||||
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
fast_topkv2.register_fake(_fake_fast_topkv2)
|
|
||||||
|
fast_topkv2.register_fake(_fake_fast_topkv2)
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------------- LoRA ops --------------------
|
||||||
|
##################################################
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- sgmv_shrink_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_shrink_lora", mutates_args=())
def sgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> torch.Tensor:
    """Custom op wrapping ``torch.ops.xspeedgate_ops.sgmv_shrink_sdnn``.

    Only ``inputs``, ``lora_a_weights``, ``seq_len_tensor``,
    ``lora_indices_tensor``, ``output_tensor`` and ``scaling`` are forwarded
    to the kernel; every other parameter is accepted but unused here.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE: ``torch.ops.xspeedgate_ops.sgmv_shrink_cluster`` is the alternative
    # entry point (same argument list); the sdnn kernel is the one wired in.
    # NOTE(review): the op is declared with ``mutates_args=()`` although
    # ``output_tensor`` is handed to the kernel -- confirm it is not written
    # in place.
    shrink_kernel = torch.ops.xspeedgate_ops.sgmv_shrink_sdnn
    return shrink_kernel(
        inputs,
        lora_a_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_shrink_lora", "CUDA")
def sgmv_shrink_lora_cuda(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> torch.Tensor:
    """Implementation of ``_C::sgmv_shrink_lora`` registered under "CUDA".

    Calls ``torch.ops.xspeedgate_ops.sgmv_shrink_sdnn`` with ``inputs``,
    ``lora_a_weights``, ``seq_len_tensor``, ``lora_indices_tensor``,
    ``output_tensor`` and ``scaling``. The remaining parameters exist only to
    match the op signature and are not used.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE: ``torch.ops.xspeedgate_ops.sgmv_shrink_cluster`` is the alternative
    # entry point (same argument list); the sdnn kernel is the one wired in.
    shrink_kernel = torch.ops.xspeedgate_ops.sgmv_shrink_sdnn
    return shrink_kernel(
        inputs,
        lora_a_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
scaling: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_shrink_lora.register_fake(_fake_sgmv_shrink_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- sgmv_expand_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_expand_lora", mutates_args=())
def sgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Custom op wrapping ``torch.ops.xspeedgate_ops.sgmv_expand_sdnn``.

    Forwards ``inputs``, ``lora_b_weights``, ``seq_len_tensor``,
    ``lora_indices_tensor`` and ``output_tensor`` followed by a literal ``0``.
    All other parameters, including ``add_inputs``, are accepted but not
    passed on.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE: ``torch.ops.xspeedgate_ops.sgmv_expand_cluster`` is the alternative
    # entry point (same argument list); the sdnn kernel is the one wired in.
    # NOTE(review): the trailing 0 presumably is the slice offset (the slice
    # variant of this op passes ``slice_offset`` in that position) -- confirm.
    expand_kernel = torch.ops.xspeedgate_ops.sgmv_expand_sdnn
    return expand_kernel(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_expand_lora", "CUDA")
def sgmv_expand_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Implementation of ``_C::sgmv_expand_lora`` registered under "CUDA".

    Calls ``torch.ops.xspeedgate_ops.sgmv_expand_sdnn`` with ``inputs``,
    ``lora_b_weights``, ``seq_len_tensor``, ``lora_indices_tensor``,
    ``output_tensor`` and a literal ``0``. The remaining parameters,
    including ``add_inputs``, are not used.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE: ``torch.ops.xspeedgate_ops.sgmv_expand_cluster`` is the alternative
    # entry point (same argument list); the sdnn kernel is the one wired in.
    expand_kernel = torch.ops.xspeedgate_ops.sgmv_expand_sdnn
    return expand_kernel(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_expand_lora.register_fake(_fake_sgmv_expand_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- sgmv_expand_slice_lora -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::sgmv_expand_slice_lora", mutates_args=())
def sgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Custom op wrapping ``torch.ops.xspeedgate_ops.sgmv_expand_cluster``.

    Forwards ``inputs``, ``lora_b_weights``, ``seq_len_tensor``,
    ``lora_indices_tensor``, ``output_tensor`` and ``slice_offset``. All other
    parameters, including ``slice_size`` and ``add_inputs``, are accepted but
    not passed on.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE(review): this variant targets the *cluster* kernel whereas
    # sgmv_expand_lora targets the *sdnn* one -- confirm that is intended.
    expand_kernel = torch.ops.xspeedgate_ops.sgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::sgmv_expand_slice_lora", "CUDA")
def sgmv_expand_slice_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Implementation of ``_C::sgmv_expand_slice_lora`` registered under "CUDA".

    Calls ``torch.ops.xspeedgate_ops.sgmv_expand_cluster`` with ``inputs``,
    ``lora_b_weights``, ``seq_len_tensor``, ``lora_indices_tensor``,
    ``output_tensor`` and ``slice_offset``. The remaining parameters,
    including ``slice_size`` and ``add_inputs``, are not used.

    Returns:
        Whatever the underlying kernel returns.
    """
    expand_kernel = torch.ops.xspeedgate_ops.sgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_sgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
batches: int,
|
||||||
|
max_seq_length: int,
|
||||||
|
token_nums: int,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
sgmv_expand_slice_lora.register_fake(_fake_sgmv_expand_slice_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- bgmv_shrink_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_shrink_lora", mutates_args=())
def bgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Custom op wrapping ``torch.ops.xspeedgate_ops.bgmv_shrink_cluster``.

    Forwards ``inputs``, ``lora_a_weights``, ``lora_indices_tensor``,
    ``output_tensor`` and ``scaling``; the other parameters are accepted but
    not passed on.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE(review): the op is declared with ``mutates_args=()`` although
    # ``output_tensor`` is handed to the kernel -- confirm it is not written
    # in place.
    shrink_kernel = torch.ops.xspeedgate_ops.bgmv_shrink_cluster
    return shrink_kernel(
        inputs,
        lora_a_weights,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_shrink_lora", "CUDA")
def bgmv_shrink_lora_cuda(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Implementation of ``_C::bgmv_shrink_lora`` registered under "CUDA".

    Calls ``torch.ops.xspeedgate_ops.bgmv_shrink_cluster`` with ``inputs``,
    ``lora_a_weights``, ``lora_indices_tensor``, ``output_tensor`` and
    ``scaling``. The remaining parameters are not used.

    Returns:
        Whatever the underlying kernel returns.
    """
    shrink_kernel = torch.ops.xspeedgate_ops.bgmv_shrink_cluster
    return shrink_kernel(
        inputs,
        lora_a_weights,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_shrink_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_a_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
expert_m: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
scaling: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
bgmv_shrink_lora.register_fake(_fake_bgmv_shrink_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# -------------- bgmv_expand_lora ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_expand_lora", mutates_args=())
def bgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Custom op wrapping ``torch.ops.xspeedgate_ops.bgmv_expand_cluster``.

    Forwards ``inputs``, ``lora_b_weights``, ``lora_indices_tensor`` and
    ``output_tensor`` followed by a literal ``0``. The other parameters,
    including ``add_inputs``, are accepted but not passed on.

    Returns:
        Whatever the underlying kernel returns.
    """
    # NOTE(review): the trailing 0 presumably is the slice offset (the slice
    # variant of this op passes ``slice_offset`` in that position) -- confirm.
    expand_kernel = torch.ops.xspeedgate_ops.bgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        output_tensor,
        0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_expand_lora", "CUDA")
def bgmv_expand_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Implementation of ``_C::bgmv_expand_lora`` registered under "CUDA".

    Calls ``torch.ops.xspeedgate_ops.bgmv_expand_cluster`` with ``inputs``,
    ``lora_b_weights``, ``lora_indices_tensor``, ``output_tensor`` and a
    literal ``0``. The remaining parameters, including ``add_inputs``, are
    not used.

    Returns:
        Whatever the underlying kernel returns.
    """
    expand_kernel = torch.ops.xspeedgate_ops.bgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        output_tensor,
        0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_expand_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Register `_fake_bgmv_expand_lora` as the fake-tensor implementation of
# the `bgmv_expand_lora` op.
bgmv_expand_lora.register_fake(_fake_bgmv_expand_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- bgmv_expand_slice_lora -------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::bgmv_expand_slice_lora", mutates_args=())
def bgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Custom op ``_C::bgmv_expand_slice_lora``.

    Forwards ``inputs``, ``lora_b_weights``, ``lora_indices_tensor``,
    ``output_tensor`` and ``slice_offset`` to
    ``torch.ops.xspeedgate_ops.bgmv_expand_cluster`` and returns its result.

    ``block_statistic``, ``sorted_tokens_num_lod``, ``moe_index``,
    ``normed_scale``, ``slice_size`` and ``add_inputs`` are part of the
    signature but are not passed on.
    """
    # NOTE(review): slice_size is never forwarded -- presumably the kernel
    # derives the slice width from the tensor shapes; confirm.
    expand_kernel = torch.ops.xspeedgate_ops.bgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::bgmv_expand_slice_lora", "CUDA")
def bgmv_expand_slice_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Implementation of ``_C::bgmv_expand_slice_lora`` registered under "CUDA".

    Delegates to ``torch.ops.xspeedgate_ops.bgmv_expand_cluster`` with
    ``inputs``, ``lora_b_weights``, ``lora_indices_tensor``,
    ``output_tensor`` and ``slice_offset``. The remaining parameters are
    accepted but not forwarded.
    """
    expand_kernel = torch.ops.xspeedgate_ops.bgmv_expand_cluster
    return expand_kernel(
        inputs,
        lora_b_weights,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_bgmv_expand_slice_lora(
|
||||||
|
inputs: torch.Tensor,
|
||||||
|
lora_b_weights: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
block_statistic: torch.Tensor,
|
||||||
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
|
moe_index: torch.Tensor,
|
||||||
|
normed_scale: torch.Tensor,
|
||||||
|
lora_indices_tensor: torch.Tensor,
|
||||||
|
slice_offset: int,
|
||||||
|
slice_size: int,
|
||||||
|
add_inputs: bool = True,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
|
||||||
|
# Register `_fake_bgmv_expand_slice_lora` as the fake-tensor implementation
# of the `bgmv_expand_slice_lora` op.
bgmv_expand_slice_lora.register_fake(_fake_bgmv_expand_slice_lora)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################
|
||||||
|
# ----------- lora_matmul_inplace ----------------
|
||||||
|
##################################################
|
||||||
|
@custom_op("_C::lora_matmul_inplace", mutates_args=())
|
||||||
|
def lora_matmul_inplace(
|
||||||
|
x: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
x_trans: bool = False,
|
||||||
|
w_trans: bool = True,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
beta: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
xtorch_ops.matmul(
|
||||||
|
x=x.contiguous(),
|
||||||
|
w=w.contiguous(),
|
||||||
|
out=output_tensor,
|
||||||
|
x_trans=x_trans,
|
||||||
|
w_trans=w_trans,
|
||||||
|
alpha=alpha,
|
||||||
|
beta=beta,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@impl("_C::lora_matmul_inplace", "CUDA")
def lora_matmul_inplace_cuda(
    x: torch.Tensor,
    w: torch.Tensor,
    output_tensor: torch.Tensor,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """Implementation of ``_C::lora_matmul_inplace`` registered under "CUDA".

    Makes ``x`` and ``w`` contiguous, then calls ``xtorch_ops.matmul`` with
    ``out=output_tensor`` and the transpose/scaling options forwarded
    unchanged. Nothing is returned; the result is delivered through
    ``output_tensor``.
    """
    x_dense = x.contiguous()
    w_dense = w.contiguous()
    xtorch_ops.matmul(
        x=x_dense,
        w=w_dense,
        out=output_tensor,
        x_trans=x_trans,
        w_trans=w_trans,
        alpha=alpha,
        beta=beta,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_lora_matmul_inplace(
|
||||||
|
x: torch.Tensor,
|
||||||
|
w: torch.Tensor,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
x_trans: bool = False,
|
||||||
|
w_trans: bool = True,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
beta: float = 1.0,
|
||||||
|
) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Register `_fake_lora_matmul_inplace` as the fake-tensor implementation of
# the `lora_matmul_inplace` op.
lora_matmul_inplace.register_fake(_fake_lora_matmul_inplace)
|
||||||
|
|||||||
Reference in New Issue
Block a user