Compare commits

..

7 Commits

Author SHA1 Message Date
starkwj
cea31d16fb add readme 2026-02-12 11:13:26 +08:00
starkwj
01bafad6d0 add vxpu 2026-02-12 11:08:07 +08:00
Xinyu Dong
070bfa4a73 [Bugfix] Fixed Kunlun Graph Failed (#193)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
2026-02-11 18:52:18 +08:00
fromck
fc48b79ae9 support glm4.7 mtp (#187)
Signed-off-by: chengxiaokang <chengxiaokang@baidu.com>
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
2026-02-11 18:32:30 +08:00
WANG HAO
bd8c999335 Further optimize multi-lora inference,LoRA-enabled performance achieves 80%+ of non-LoRA performance (#190)
* optimize lora inference

Signed-off-by: wanghao <wanghao@example.com>

* further optimize multi-lora inference,LoRA-enabled performance achieves 80%+ of non-LoRA performance

Signed-off-by: wanghao <wanghao@example.com>

---------

Signed-off-by: wanghao <wanghao@example.com>
Co-authored-by: wanghao <wanghao@example.com>
2026-02-11 12:04:14 +08:00
WeiJie_Hong
9b1f25fbe3 [Doc] update xspeedgate_ops (20260130) (#188)
Signed-off-by: WeiJie_Hong <1462519292@qq.com>
2026-02-10 18:05:20 +08:00
WeiJie_Hong
42c7ef2f27 [Doc] add DeepSeek-V3.2-Exp-w8a8 to installation.md and tutorials (#186)
Signed-off-by: WeiJie_Hong <1462519292@qq.com>
2026-02-10 17:18:32 +08:00
9 changed files with 1235 additions and 414 deletions

View File

@@ -1,76 +1,76 @@
name: e2e-test # name: e2e-test
on: # on:
workflow_call: # workflow_call:
pull_request: # pull_request:
branches: [main] # branches: [main]
types: [opened, synchronize, reopened] # types: [opened, synchronize, reopened]
push: # push:
branches: [main] # branches: [main]
concurrency: # concurrency:
group: e2e-singlecard # group: e2e-singlecard
cancel-in-progress: false # cancel-in-progress: false
jobs: # jobs:
e2e: # e2e:
name: e2e-test-singlecard # name: e2e-test-singlecard
runs-on: # runs-on:
- self-hosted # - self-hosted
- Linux # - Linux
- X64 # - X64
steps: # steps:
- name: Checkout PR code # - name: Checkout PR code
uses: actions/checkout@v4 # uses: actions/checkout@v4
with: # with:
fetch-depth: 0 # fetch-depth: 0
- name: Verify PR workspace # - name: Verify PR workspace
run: | # run: |
echo "===== WORKSPACE =====" # echo "===== WORKSPACE ====="
pwd # pwd
ls -l # ls -l
echo "===== GIT INFO =====" # echo "===== GIT INFO ====="
git rev-parse HEAD # git rev-parse HEAD
git log -1 --oneline # git log -1 --oneline
git status --porcelain # git status --porcelain
- name: Start docker # - name: Start docker
run: | # run: |
bash ci/scripts/docker/start_docker.sh # bash ci/scripts/docker/start_docker.sh
- name: Install enviroments # - name: Install enviroments
env: # env:
PROXY_URL: ${{ secrets.PROXY_URL }} # PROXY_URL: ${{ secrets.PROXY_URL }}
NO_PROXY_LIST: ${{ secrets.NO_PROXY_LIST }} # NO_PROXY_LIST: ${{ secrets.NO_PROXY_LIST }}
run: | # run: |
bash ci/scripts/env/install_env.sh # bash ci/scripts/env/install_env.sh
- name: Start vLLM server # - name: Start vLLM server
run: | # run: |
bash ci/scripts/server/start_vllm.sh # bash ci/scripts/server/start_vllm.sh
- name: Wait for vLLM ready # - name: Wait for vLLM ready
run: | # run: |
bash ci/scripts/server/wait_vllm.sh # bash ci/scripts/server/wait_vllm.sh
- name: API Test # - name: API Test
run: | # run: |
docker exec aiak-e2e-singlecard bash -lc ' # docker exec aiak-e2e-singlecard bash -lc '
curl http://localhost:8356/v1/chat/completions \ # curl http://localhost:8356/v1/chat/completions \
-H "Content-Type: application/json" \ # -H "Content-Type: application/json" \
-d @- << "EOF" # -d @- << "EOF"
{ # {
"model": "Qwen3-8B", # "model": "Qwen3-8B",
"messages": [ # "messages": [
{ "role": "user", "content": "Who are you?" } # { "role": "user", "content": "Who are you?" }
], # ],
"max_tokens": 200, # "max_tokens": 200,
"temperature": 0 # "temperature": 0
} # }
EOF # EOF
' # '
# - name: Accuracy testing # - name: Accuracy testing
# run: | # run: |

View File

@@ -11,7 +11,9 @@ This document describes how to install vllm-kunlun manually.
- vLLM (same version as vllm-kunlun) - vLLM (same version as vllm-kunlun)
## Setup environment using container ## Setup environment using container
We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:base_v0.0.2` and `wjie520/vllm_kunlun:base_mimo_v0.0.2`(Only MIMO_V2 and GPT-OSS).You can pull it using the `docker pull` command. We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:uv_base`.You can pull it using the `docker pull wjie520/vllm_kunlun:uv_base` command.
We also provide images with xpytorch and ops installed.You can pull it using the `wjie520/vllm_kunlun:base_v0.0.2 and wjie520/vllm_kunlun:base_mimo_v0.0.2 (Only MIMO_V2 and GPT-OSS)` command
### Container startup script ### Container startup script
:::::{tab-set} :::::{tab-set}
@@ -19,9 +21,8 @@ We provide a clean, minimal base image for your use`wjie520/vllm_kunlun:base_v0.
::::{tab-item} start_docker.sh ::::{tab-item} start_docker.sh
:selected: :selected:
:sync: pip :sync: uv pip
```{code-block} bash ```{code-block} bash
:substitutions:
#!/bin/bash #!/bin/bash
XPU_NUM=8 XPU_NUM=8
DOCKER_DEVICE_CONFIG="" DOCKER_DEVICE_CONFIG=""
@@ -31,7 +32,7 @@ if [ $XPU_NUM -gt 0 ]; then
done done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl" DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi fi
export build_image="wjie520/vllm_kunlun:base_v0.0.2" export build_image="wjie520/vllm_kunlun:uv_base"
# or export build_image="iregistry.baidu-int.com/xmlir/xmlir_ubuntu_2004_x86_64:v0.32" # or export build_image="iregistry.baidu-int.com/xmlir/xmlir_ubuntu_2004_x86_64:v0.32"
docker run -itd ${DOCKER_DEVICE_CONFIG} \ docker run -itd ${DOCKER_DEVICE_CONFIG} \
@@ -63,8 +64,71 @@ uv pip install -r requirements.txt
python setup.py build python setup.py build
python setup.py install python setup.py install
```
### Replace eval_frame.py
Copy the eval_frame.py patch:
```
cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
```
## Choose to download customized xpytorch
### Install the KL3-customized build of PyTorch
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-02T05%3A01%3A27Z%2F-1%2Fhost%2Ff3cf499234f82303891aed2bcb0628918e379a21e841a3fac6bd94afef491ff7
(for the conda)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
(for the uv)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
```
### Install the KL3-customized build of PyTorch (Only MIMO V2)
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run
(for the conda)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
(for the uv)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
``` ```
### Install the KL3-customized build of PyTorch (Only DeepSeek-V3.2-Exp-w8a8)
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://aihc-private-hcd.bj.bcebos.com/v1/vllm-kunlun-ds/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2026-02-03T01%3A59%3A40Z%2F-1%2Fhost%2Ffc4b6f5b83c2fde70d48fdfc23c40c396efc9cb3c36d6f811fdca5f109073321
(for the conda)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
(for the uv)
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
mv torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g; s/torch_xray-999.9.9/torch_xray-2.0.3/' setup.sh && bash setup.sh
```
## Choose to download customized ops
### Install custom ops
```
uv pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
```
### Install custom ops (Only MIMO V2)
```
uv pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
```
### Install custom ops (Only DeepSeek-V3.2-Exp-w8a8)
```
uv pip install "https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1215/xtorch_ops-0.1.2263%2Bc030eebd-cp310-cp310-linux_x86_64.whl"
```
## Install the KLX3 custom Triton build
```
uv pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
```
## Install the AIAK custom ops library
```
uv pip install "https://vllm-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20260130_152557/xspeedgate_ops-0.0.0%2Be5cdcbe-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKhvtgrTA8US5LIc8Vbl0mP%2F2026-01-30T10%3A33%3A32Z%2F2592000%2Fhost%2F3c13d67cc61d0df7538c198f5c32422f3b034068a40eef43cb51b079cc6f0555" --force-reinstall
```
## Quick Start ## Quick Start
### Set up the environment ### Set up the environment
@@ -81,7 +145,6 @@ chmod +x /workspace/vLLM-Kunlun/setup_env.sh && source /workspace/vLLM-Kunlun/se
:selected: :selected:
:sync: pip :sync: pip
```{code-block} bash ```{code-block} bash
:substitutions:
python -m vllm.entrypoints.openai.api_server \ python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 8356 \ --port 8356 \
@@ -112,41 +175,3 @@ python -m vllm.entrypoints.openai.api_server \
``` ```
:::: ::::
::::: :::::
### xpytorch and ops install
We also provide xpytorch and ops link for custom installation.
### Replace eval_frame.py
Copy the eval_frame.py patch:
```
cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py
```
## Install the KL3-customized build of PyTorch
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-02T05%3A01%3A27Z%2F-1%2Fhost%2Ff3cf499234f82303891aed2bcb0628918e379a21e841a3fac6bd94afef491ff7
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
```
## Install the KL3-customized build of PyTorch(Only MIMO V2)
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
```
## Install custom ops
```
pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
```
## Install custom ops(Only MIMO V2)
```
pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
```
## Install the KLX3 custom Triton build
```
pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
```
## Install the AIAK custom ops library
```
pip install "https://cce-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20251219_152418/xspeedgate_ops-0.0.0-cp310-cp310-linux_x86_64.whl"
```

View File

@@ -0,0 +1,140 @@
# Multi XPU (DeepSeek-V3.2-Exp-w8a8)
## Run vllm-kunlun on Multi XPU
Setup environment using container:
Please follow the [installation.md](../installation.md) document to set up the environment first.
Create a container
```bash
# !/bin/bash
# rundocker.sh
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="xxx"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
### Preparation Weight
- Pull DeepSeek-V3.2-Exp-w8a8-int8 weights
```
wget -O DeepSeek-V3.2-Exp-w8a8-int8.tar.gz https://aihc-private-hcd.bj.bcebos.com/v1/LLM/DeepSeek/DeepSeek-V3.2-Exp-w8a8-int8.tar.gz?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2025-12-24T06%3A07%3A10Z%2F-1%2Fhost%2Fa324bf469176934a05f75d3acabc3c1fb891be150f43fb1976e65b7ec68733db
```
- Ensure that the field "quantization_config" is included.If not, deployment will result in an OOM (Out of Memory) error.
vim model/DeepSeek-V3.2-Exp-w8a8-int8/config.json
```config.json
"quantization_config": {
"config_groups": {
"group_0": {
"format": "int-quantized",
"input_activations": {
"actorder": null,
"block_structure": null,
"dynamic": true,
"group_size": null,
"num_bits": 8,
"observer": null,
"observer_kwargs": {},
"strategy": "token",
"symmetric": true,
"type": "int"
},
"output_activations": null,
"targets": [
"Linear"
],
"weights": {
"actorder": null,
"block_structure": null,
"dynamic": false,
"group_size": null,
"num_bits": 8,
"observer": "minmax",
"observer_kwargs": {},
"strategy": "channel",
"symmetric": true,
"type": "int"
}
}
},
"format": "int-quantized",
"global_compression_ratio": null,
"ignore": [
"lm_head"
],
"kv_cache_scheme": null,
"quant_method": "compressed-tensors",
"quantization_status": "compressed",
"sparsity_config": {},
"transform_config": {},
"version": "0.12.2"
},
```
### Online Serving on Multi XPU
Start the vLLM server on multi XPU:
```bash
unset XPU_DUMMY_EVENT && \
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
export XMLIR_CUDNN_ENABLED=1 && \
export XPU_USE_DEFAULT_CTX=1 && \
export XMLIR_FORCE_USE_XPU_GRAPH=1 && \
export XMLIR_ENABLE_FAST_FC=1 && \
export XPU_USE_FAST_SWIGLU=1 && \
export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
export XPU_USE_MOE_SORTED_THRES=1 && \
export USE_ORI_ROPE=1 && \
export VLLM_USE_V1=1
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8806 \
--model /data/DeepSeek-V3.2-Exp-w8a8-int8 \
--gpu-memory-utilization 0.95 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 8 \
--dtype float16 \
--max_num_seqs 32 \
--max_num_batched_tokens 8192 \
--block-size 64 \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--disable-log-requests \
--no-enable-prefix-caching --kv-cache-dtype bfloat16 \
--compilation-config '{"splitting_ops":["vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_attention_with_output_kunlun",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer",
"vllm.sparse_attn_indexer_vllm_kunlun"]}'
```

View File

@@ -34,7 +34,9 @@ static inline std::string get_shm_name() {
} }
static constexpr uint32_t heartbeat_us = 1000; // microseconds static constexpr uint32_t heartbeat_us = 1000; // microseconds
static constexpr uint32_t heartbeat_timeout_us = 20 * heartbeat_us; static constexpr uint32_t heartbeat_check_everyN = 50;
static constexpr uint32_t heartbeat_timeout_us =
heartbeat_check_everyN * heartbeat_us;
struct alignas(64) WorkerHeartBeat { struct alignas(64) WorkerHeartBeat {
std::atomic<uint64_t> timestamp; std::atomic<uint64_t> timestamp;

View File

@@ -51,17 +51,16 @@ void ShmManager::set_xpu_info(int device_id, uint32_t xpu_pci_addr,
void ShmManager::run_busy_loop() { void ShmManager::run_busy_loop() {
spdlog::info("ShmManager busy loop started"); spdlog::info("ShmManager busy loop started");
int heart_beat_check_everyN = 20;
int loop_cnt = 0; int loop_cnt = 0;
while (!stop_loop_flag.load(std::memory_order_acquire)) { while (!stop_loop_flag.load(std::memory_order_acquire)) {
process_requests(); process_requests();
if (loop_cnt % heart_beat_check_everyN == 0) { if (loop_cnt % heartbeat_check_everyN== 0) {
check_heart_beats(); check_heart_beats();
} }
loop_cnt = (loop_cnt + 1) % heartbeat_check_everyN;
loop_cnt = (loop_cnt + 1) % heart_beat_check_everyN;
usleep(heartbeat_us); usleep(heartbeat_us);
} }

View File

@@ -1,11 +1,6 @@
"""kunlun_ops for lora""" """kunlun_ops for lora"""
import torch import torch
import xspeedgate_ops
import time
from torch._C import dtype
import os
from torch._dynamo import disable
def sgmv_shrink( def sgmv_shrink(
@@ -27,13 +22,18 @@ def sgmv_shrink(
""" """
sgmv_shrink sgmv_shrink
""" """
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
inputs,
lora_a_weights,
seq_len_tensor.to(torch.int32),
lora_indices_tensor.to(torch.int32),
output_tensor,
scaling,
)
return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling) def sgmv_expand(
inputs: torch.Tensor,
def sgmv_expand(inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
block_statistic: torch.Tensor, block_statistic: torch.Tensor,
@@ -45,16 +45,23 @@ def sgmv_expand(inputs: torch.Tensor,
batches: int, batches: int,
max_seq_length: int, max_seq_length: int,
token_nums: int, token_nums: int,
add_inputs: bool = False): add_inputs: bool = False,
):
""" """
sgmv_expand sgmv_expand
""" """
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0) inputs,
lora_b_weights,
seq_len_tensor.to(torch.int32),
lora_indices_tensor.to(torch.int32),
output_tensor,
0,
)
def sgmv_expand_slice(
def sgmv_expand_slice(inputs: torch.Tensor, inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
block_statistic: torch.Tensor, block_statistic: torch.Tensor,
@@ -69,18 +76,19 @@ def sgmv_expand_slice(inputs: torch.Tensor,
token_nums: int, token_nums: int,
slice_offset: int, slice_offset: int,
slice_size: int, slice_size: int,
add_inputs: bool = False): add_inputs: bool = False,
):
""" """
sgmv_expand_slice sgmv_expand_slice
""" """
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
inputs,
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, slice_offset) lora_b_weights,
seq_len_tensor.to(torch.int32),
lora_indices_tensor.to(torch.int32),
output_tensor,
slice_offset,
)
def bgmv_shrink( def bgmv_shrink(
@@ -92,27 +100,33 @@ def bgmv_shrink(
moe_index: torch.Tensor, moe_index: torch.Tensor,
expert_m: torch.Tensor, expert_m: torch.Tensor,
lora_indices_tensor: torch.Tensor, # [m] lora_indices_tensor: torch.Tensor, # [m]
scaling: float = 1.0 scaling: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
bgmv_shrink bgmv_shrink
""" """
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling) return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
)
def bgmv_expand(inputs: torch.Tensor, def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor, lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor, output_tensor: torch.Tensor,
block_statistic: torch.Tensor, block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor, sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor, moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
add_inputs: bool = True): add_inputs: bool = True,
"""" ):
""" "
bgmv_expand bgmv_expand
""" """
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0) return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
# @my_wrapper inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
)
def bgmv_expand_slice( def bgmv_expand_slice(
inputs: torch.Tensor, inputs: torch.Tensor,
@@ -125,9 +139,11 @@ def bgmv_expand_slice(
lora_indices_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor,
slice_offset: int, slice_offset: int,
slice_size: int, slice_size: int,
add_inputs: bool = True add_inputs: bool = True,
): ):
""" """
bgmv_expand_slice bgmv_expand_slice
""" """
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset) return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
)

View File

@@ -22,16 +22,11 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547 https://arxiv.org/abs/2310.18547
""" """
from typing import TYPE_CHECKING, Optional, Union, final
import torch
# Disable torchdynamo for all functions in this file
torch._dynamo.config.disable = True
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Tuple, Union
import torch
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm_kunlun.lora.ops.kunlun_ops import ( from vllm_kunlun.lora.ops.kunlun_ops import (
bgmv_expand, bgmv_expand,
@@ -42,7 +37,7 @@ from vllm_kunlun.lora.ops.kunlun_ops import (
sgmv_shrink, sgmv_shrink,
) )
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase # Disable torchdynamo for all functions in this file
# The platforms that are compatible with the PyTorch-native implementation can # The platforms that are compatible with the PyTorch-native implementation can

View File

@@ -14,39 +14,53 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-kunlun project. # This file is a part of the vllm-kunlun project.
# #
from vllm.config import VllmConfig, get_layers_from_vllm_config
import xtorch_ops
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, ClassVar, Tuple, Type, TYPE_CHECKING from itertools import accumulate
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
)
import torch
import numpy as np import numpy as np
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, import torch
AttentionMetadata, AttentionLayer, AttentionType) import xtorch_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionType,
)
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
split_decodes_and_prefills,
)
# from vllm.attention.backends.utils import CommonAttentionState # from vllm.attention.backends.utils import CommonAttentionState
# from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping # from vllm.attention.backends.utils import is_block_tables_empty, compute_slot_mapping_start_idx, compute_slot_mapping
from vllm_kunlun.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) from vllm_kunlun.ops.paged_attn import PagedAttention, PagedAttentionMetadata
from vllm_kunlun.ops._kunlun_ops import KunlunOps
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
AttentionCGSupport,
split_decodes_and_prefills)
from vllm.forward_context import ForwardContext, get_forward_context
from itertools import accumulate
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
import inspect
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.config import VllmConfig, get_layers_from_vllm_config
import inspect
class KunlunAttentionBackend(AttentionBackend): class KunlunAttentionBackend(AttentionBackend):
"""KunlunAttentionBackend""" """KunlunAttentionBackend"""
# crucial to cuda graph # crucial to cuda graph
accept_output_buffer = True accept_output_buffer = True
@@ -81,12 +95,13 @@ class KunlunAttentionBackend(AttentionBackend):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto" cache_dtype_str: str = "auto",
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
"""get_kv_cache_shape""" """get_kv_cache_shape"""
# return (2, num_blocks, block_size, num_kv_heads * head_size) # return (2, num_blocks, block_size, num_kv_heads * head_size)
return PagedAttention.get_kv_cache_shape(num_blocks, block_size, return PagedAttention.get_kv_cache_shape(
num_kv_heads, head_size) num_blocks, block_size, num_kv_heads, head_size
)
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
@@ -110,7 +125,6 @@ class KunlunAttentionBackend(AttentionBackend):
class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata): class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
"""KunlunMetadata""" """KunlunMetadata"""
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------| # |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---| # |- tokenA -|......................|-- newTokens ---|
@@ -208,6 +222,8 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens: int = 0 num_decode_tokens: int = 0
num_prefills: int = 0 num_prefills: int = 0
num_decodes: int = 0 num_decodes: int = 0
is_speculative: Optional[bool] = False
max_model_len: int = 0
def __post_init__(self): def __post_init__(self):
"""__post_init__""" """__post_init__"""
@@ -218,16 +234,20 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
@property @property
def is_all_encoder_attn_metadata_set(self): def is_all_encoder_attn_metadata_set(self):
"""is_all_encoder_attn_metadata_set""" """is_all_encoder_attn_metadata_set"""
return ((self.encoder_seq_lens is not None) return (
(self.encoder_seq_lens is not None)
and (self.encoder_seq_lens_tensor is not None) and (self.encoder_seq_lens_tensor is not None)
and (self.max_encoder_seq_len is not None)) and (self.max_encoder_seq_len is not None)
)
@property @property
def is_all_cross_attn_metadata_set(self): def is_all_cross_attn_metadata_set(self):
"""is_all_cross_attn_metadata_set""" """is_all_cross_attn_metadata_set"""
return (self.is_all_encoder_attn_metadata_set return (
self.is_all_encoder_attn_metadata_set
and (self.cross_slot_mapping is not None) and (self.cross_slot_mapping is not None)
and (self.cross_block_tables is not None)) and (self.cross_block_tables is not None)
)
@property @property
def prefill_metadata(self) -> Optional["KunlunMetadata"]: def prefill_metadata(self) -> Optional["KunlunMetadata"]:
@@ -240,34 +260,59 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# metadata structure # metadata structure
return self._cached_prefill_metadata return self._cached_prefill_metadata
assert ((self.seq_lens_tensor is not None) assert (self.seq_lens_tensor is not None) or (
or (self.encoder_seq_lens_tensor is not None)) self.encoder_seq_lens_tensor is not None
)
# Compute some attn_metadata fields which default to None # Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else query_start_loc = (
self.query_start_loc[-(self.num_prefills + 1):] - self.query_start_loc[-(self.num_prefills + 1)]) None
if self.query_start_loc is None
else self.query_start_loc[-(self.num_prefills + 1) :]
- self.query_start_loc[-(self.num_prefills + 1)]
)
# flash attention needs both lod information on host and device # flash attention needs both lod information on host and device
query_start_loc_host = (None if self.query_start_loc_host is None else query_start_loc_host = (
self.query_start_loc_host[-(self.num_prefills + 1):] - self.query_start_loc_host[-(self.num_prefills + 1)]) None
if self.query_start_loc_host is None
else self.query_start_loc_host[-(self.num_prefills + 1) :]
- self.query_start_loc_host[-(self.num_prefills + 1)]
)
# TODO(chengruichang):how to support prefix cache # TODO(chengruichang):how to support prefix cache
kv_prefix_start_loc_host = None kv_prefix_start_loc_host = None
kv_prefix_start_loc = None kv_prefix_start_loc = None
slot_mapping = (None if self.slot_mapping is None else slot_mapping = (
self.slot_mapping[-self.num_prefill_tokens:]) None
if self.slot_mapping is None
else self.slot_mapping[-self.num_prefill_tokens :]
)
seq_lens_tensor = (None if self.seq_lens_tensor is None else seq_lens_tensor = (
self.seq_lens_tensor[-self.num_prefills:]) None
seq_lens = (None if self.seq_lens is None else self.seq_lens[-self.num_prefills:]) if self.seq_lens_tensor is None
else self.seq_lens_tensor[-self.num_prefills :]
)
seq_lens = (
None if self.seq_lens is None else self.seq_lens[-self.num_prefills :]
)
context_lens_tensor = (None if self.context_lens_tensor is None else context_lens_tensor = (
self.context_lens_tensor[-self.num_prefills:]) None
if self.context_lens_tensor is None
block_tables = (None if self.block_tables is None else else self.context_lens_tensor[-self.num_prefills :]
self.block_tables[-self.num_prefills:]) )
input_positions = (None if self.input_positions is None else
self.input_positions[-self.num_prefills:])
block_tables = (
None
if self.block_tables is None
else self.block_tables[-self.num_prefills :]
)
input_positions = (
None
if self.input_positions is None
else self.input_positions[-self.num_prefills :]
)
if self.kv_lod_cpu is None: if self.kv_lod_cpu is None:
kv_lod_cpu = None kv_lod_cpu = None
@@ -280,19 +325,17 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
base_xpu = self.kv_lod_xpu[start] base_xpu = self.kv_lod_xpu[start]
kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu kv_lod_xpu = self.kv_lod_xpu[start:] - base_xpu
# Construct & cache prefill-phase attention metadata structure # Construct & cache prefill-phase attention metadata structure
self._cached_prefill_metadata = KunlunMetadata( self._cached_prefill_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens, num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps,
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0, num_decode_tokens=0,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
seq_start_loc = None, seq_start_loc=None,
kv_lod_cpu=kv_lod_cpu, kv_lod_cpu=kv_lod_cpu,
kv_lod_xpu=kv_lod_xpu, kv_lod_xpu=kv_lod_xpu,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
@@ -314,7 +357,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False, enable_kv_scales_calculation=False,
use_cascade=self.use_cascade) use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
)
return self._cached_prefill_metadata return self._cached_prefill_metadata
@property @property
@@ -327,40 +372,47 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
# Recover cached decode-phase attention # Recover cached decode-phase attention
# metadata structure # metadata structure
return self._cached_decode_metadata return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None) assert (self.seq_lens_tensor is not None) or (
or (self.encoder_seq_lens_tensor is not None)) self.encoder_seq_lens_tensor is not None
)
if self.num_prefills != 0: if self.num_prefills != 0:
# Compute some attn_metadata fields which default to None # Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else slot_mapping = (
self.slot_mapping[:-self.num_prefill_tokens]) None
seq_lens_tensor = (None if self.seq_lens_tensor is None else if self.slot_mapping is None
self.seq_lens_tensor[:-self.num_prefills]) else self.slot_mapping[: -self.num_prefill_tokens]
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else )
self.seq_lens_tensor_cpu[:-self.num_prefills]) seq_lens_tensor = (
None
block_tables = (None if self.block_tables is None else if self.seq_lens_tensor is None
self.block_tables[:-self.num_prefills]) else self.seq_lens_tensor[: -self.num_prefills]
)
seq_lens_tensor_cpu = (
None
if self.seq_lens_tensor_cpu is None
else self.seq_lens_tensor_cpu[: -self.num_prefills]
)
block_tables = (
None
if self.block_tables is None
else self.block_tables[: -self.num_prefills]
)
else: else:
# Compute some attn_metadata fields which default to None # Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else slot_mapping = None if self.slot_mapping is None else self.slot_mapping
self.slot_mapping) seq_lens_tensor = (
seq_lens_tensor = (None if self.seq_lens_tensor is None else None if self.seq_lens_tensor is None else self.seq_lens_tensor
self.seq_lens_tensor) )
seq_lens_tensor_cpu = (
seq_lens_tensor_cpu = (None if self.seq_lens_tensor_cpu is None else None if self.seq_lens_tensor_cpu is None else self.seq_lens_tensor_cpu
self.seq_lens_tensor_cpu) )
block_tables = None if self.block_tables is None else self.block_tables
block_tables = (None if self.block_tables is None else
self.block_tables)
# Construct & cache decode-phase attention metadata structure # Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = KunlunMetadata( self._cached_decode_metadata = KunlunMetadata(
num_actual_tokens=self.num_actual_tokens, num_actual_tokens=self.num_actual_tokens,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.multi_modal_placeholder_index_maps,
multi_modal_placeholder_index_maps,
num_prefills=0, num_prefills=0,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
@@ -378,19 +430,29 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
cross_slot_mapping=self.cross_slot_mapping, cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables, cross_block_tables=self.cross_block_tables,
enable_kv_scales_calculation=False, enable_kv_scales_calculation=False,
use_cascade=self.use_cascade) use_cascade=self.use_cascade,
is_speculative=self.is_speculative,
max_model_len=self.max_model_len,
)
return self._cached_decode_metadata return self._cached_decode_metadata
M = TypeVar("M")
class KunlunAttentionMetadataBuilder: class KunlunAttentionMetadataBuilder:
"""KunlunAttentionMetadataBuilder""" """KunlunAttentionMetadataBuilder"""
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[Optional[int]] = 1 reorder_batch_threshold: ClassVar[Optional[int]] = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(
vllm_config: VllmConfig, device: torch.device): self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
"""__init__""" """__init__"""
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
@@ -398,17 +460,45 @@ class KunlunAttentionMetadataBuilder:
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.num_heads_q = self.model_config.get_num_attention_heads( self.num_heads_q = self.model_config.get_num_attention_heads(
self.parallel_config) self.parallel_config
self.num_heads_kv = self.model_config.get_num_kv_heads( )
self.parallel_config) self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
self.headdim = self.model_config.get_head_size() self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.device = device self.device = device
def reorder_batch(self, input_batch: "InputBatch", def _init_reorder_batch_threshold(
scheduler_output: "SchedulerOutput") -> bool: self,
reorder_batch_threshold: int | None = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (
speculative_config is not None
and speculative_config.num_speculative_tokens is not None
):
self.reorder_batch_threshold = max(
self.reorder_batch_threshold,
1 + speculative_config.num_speculative_tokens,
)
if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""reorder_batch""" """reorder_batch"""
decodes = [] decodes = []
prefills = [] prefills = []
@@ -432,8 +522,9 @@ class KunlunAttentionMetadataBuilder:
for i in range(1, min(num_decodes, num_prefills) + 1): for i in range(1, min(num_decodes, num_prefills) + 1):
if decodes[num_decodes - i] >= num_decodes: if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill], input_batch.swap_states(
decodes[num_decodes - i]) prefills[first_prefill], decodes[num_decodes - i]
)
first_prefill += 1 first_prefill += 1
modified_batch = True modified_batch = True
else: else:
@@ -454,8 +545,30 @@ class KunlunAttentionMetadataBuilder:
attn_metadata.seq_lens_tensor.fill_(1) attn_metadata.seq_lens_tensor.fill_(1)
return attn_metadata return attn_metadata
def build(self, common_prefix_len: int, def build_for_drafting(
common_attn_metadata: CommonAttentionMetadata): self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> M:
"""
Build attention metadata for draft model. Uses build by default.
Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
When speculating a chain of tokens, this index refers to the
draft attempt for the i-th token.
For tree-based attention, this index instead refers to the
draft attempt for the i-th level in the tree of tokens.
"""
return self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
def build(
self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
):
"""build""" """build"""
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
@@ -464,30 +577,38 @@ class KunlunAttentionMetadataBuilder:
block_table_tensor = common_attn_metadata.block_table_tensor block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[:num_reqs + 1].to( query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
self.device, non_blocking=True) self.device, non_blocking=True
)
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_start_loc = list(accumulate(seq_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0))
seq_start_loc_tensor = torch.empty(
len(seq_start_loc), dtype=torch.int32, device=self.device
seq_start_loc_tensor = torch.empty(len(seq_start_loc), dtype=torch.int32, device=self.device) )
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32)) seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu") kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0) kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
kv_lod_xpu = kv_lod_cpu.to(self.device) kv_lod_xpu = kv_lod_cpu.to(self.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
split_decodes_and_prefills(common_attn_metadata) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)
num_scheduled_tokens = np.diff(common_attn_metadata.query_start_loc_cpu[:num_reqs + 1]) num_scheduled_tokens = np.diff(
common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
)
tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes] tmp_decode_scheduled_tokens = num_scheduled_tokens[:num_decodes]
if num_decode_tokens == 0: if num_decode_tokens == 0:
@@ -495,7 +616,7 @@ class KunlunAttentionMetadataBuilder:
else: else:
max_decode_seq_len = np.max(tmp_decode_scheduled_tokens) max_decode_seq_len = np.max(tmp_decode_scheduled_tokens)
tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes: num_reqs] tmp_prefill_scheduled_tokens = num_scheduled_tokens[num_decodes:num_reqs]
if num_prefill_tokens == 0: if num_prefill_tokens == 0:
max_prefill_seq_len = 0 max_prefill_seq_len = 0
@@ -507,6 +628,7 @@ class KunlunAttentionMetadataBuilder:
attn_metadata = KunlunMetadata( attn_metadata = KunlunMetadata(
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
num_prefills=num_prefills, num_prefills=num_prefills,
num_decodes=num_decodes,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True, enable_kv_scales_calculation=True,
@@ -525,11 +647,14 @@ class KunlunAttentionMetadataBuilder:
block_tables=block_table_tensor, block_tables=block_table_tensor,
use_cuda_graph=False, use_cuda_graph=False,
use_cascade=use_cascade, use_cascade=use_cascade,
is_speculative=self.reorder_batch_threshold > 1,
max_model_len=self.vllm_config.model_config.max_model_len,
) )
return attn_metadata return attn_metadata
def can_run_in_cudagraph( def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool: self, common_attn_metadata: CommonAttentionMetadata
) -> bool:
"""can_run_in_cudagraph""" """can_run_in_cudagraph"""
# Full CUDA Graph always supported (FA2 support checked separately) # Full CUDA Graph always supported (FA2 support checked separately)
return True return True
@@ -538,6 +663,7 @@ class KunlunAttentionMetadataBuilder:
"""use_cascade_attention""" """use_cascade_attention"""
return use_cascade_attention(*args, **kwargs) return use_cascade_attention(*args, **kwargs)
class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]): class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
"""KunlunAttentionImpl""" """KunlunAttentionImpl"""
@@ -555,13 +681,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False, use_irope: bool = False,
sinks:Optional[torch.Tensor]= None, sinks: Optional[torch.Tensor] = None,
multi_modal_placeholder_index_maps:Optional[torch.Tensor]= None, multi_modal_placeholder_index_maps: Optional[torch.Tensor] = None,
) -> None: ) -> None:
"""__init__""" """__init__"""
if blocksparse_params is not None: if blocksparse_params is not None:
raise ValueError( raise ValueError("kunlunAttention does not support block-sparse attention.")
"kunlunAttention does not support block-sparse attention.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
@@ -582,14 +707,16 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if head_size not in suppored_head_sizes: if head_size not in suppored_head_sizes:
raise ValueError( raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. " f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.") f"Supported head sizes are: {suppored_head_sizes}."
)
self.sinks = sinks self.sinks = sinks
if sinks is not None: if sinks is not None:
assert sinks.shape[0] == num_heads, ( assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of " "Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, " f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.") f"num_heads: {num_heads}."
)
self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps self.multi_modal_placeholder_index_maps = multi_modal_placeholder_index_maps
def forward( def forward(
@@ -605,7 +732,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""forward""" """forward"""
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
@@ -624,7 +751,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Self-attention vs. cross-attention will impact # Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which # which KV cache memory-mapping & which
# seqlen datastructures we utilize # seqlen datastructures we utilize
if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
# KV-cache during decoder-self- or # KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not # encoder-decoder-cross-attention, but not
# during encoder attention. # during encoder attention.
@@ -633,7 +760,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# we still need to break out key_cache and value_cache # we still need to break out key_cache and value_cache
# i.e. for later use by paged attention # i.e. for later use by paged attention
key_cache, value_cache = PagedAttention.split_kv_cache( key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size) kv_cache, self.num_kv_heads, self.head_size
)
if (key is not None) and (value is not None): if (key is not None) and (value is not None):
updated_slot_mapping = attn_metadata.slot_mapping updated_slot_mapping = attn_metadata.slot_mapping
@@ -644,11 +772,12 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value = value.contiguous() value = value.contiguous()
if key_cache.is_contiguous(): if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache( xtorch_ops.reshape_and_cache(
key, key[: attn_metadata.num_actual_tokens],
value, value[: attn_metadata.num_actual_tokens],
key_cache, key_cache,
value_cache, value_cache,
updated_slot_mapping) updated_slot_mapping,
)
else: else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2) cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2) cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
@@ -657,7 +786,8 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
value, value,
cast_key_cache, cast_key_cache,
cast_value_cache, cast_value_cache,
updated_slot_mapping) updated_slot_mapping,
)
assert attn_type == AttentionType.DECODER assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill. # Decoder self-attention supports chunked prefill.
@@ -668,9 +798,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
prefill_query = query[num_decode_tokens:attn_metadata.num_actual_tokens] prefill_query = query[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_key = key[num_decode_tokens:attn_metadata.num_actual_tokens] prefill_key = key[num_decode_tokens : attn_metadata.num_actual_tokens]
prefill_value = value[num_decode_tokens:attn_metadata.num_actual_tokens] prefill_value = value[num_decode_tokens : attn_metadata.num_actual_tokens]
# For hybrid Attention (Qwen3-Next.) # For hybrid Attention (Qwen3-Next.)
if key_cache.is_contiguous(): if key_cache.is_contiguous():
@@ -685,7 +815,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
q=prefill_query, q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim] k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache, v=value_cache,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens], out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True, is_causal=True,
is_prefix_cache=True, is_prefix_cache=True,
block_table=tmp_block_tables, block_table=tmp_block_tables,
@@ -694,35 +824,41 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu, context_kvlen_lod_cpu=prefill_meta.kv_lod_cpu,
context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu, context_kvlen_lod_xpu=prefill_meta.kv_lod_xpu,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softmax_lse=None softmax_lse=None,
) )
else: else:
xtorch_ops.prefill_attention( xtorch_ops.prefill_attention(
q=prefill_query, q=prefill_query,
k=prefill_key, k=prefill_key,
v=prefill_value, v=prefill_value,
out=output[num_decode_tokens:attn_metadata.num_actual_tokens], out=output[num_decode_tokens : attn_metadata.num_actual_tokens],
is_causal=True, is_causal=True,
context_qlen_lod_cpu=prefill_meta.query_start_loc_host, context_qlen_lod_cpu=prefill_meta.query_start_loc_host,
context_qlen_lod_xpu=prefill_meta.query_start_loc, context_qlen_lod_xpu=prefill_meta.query_start_loc,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softmax_lse=None, softmax_lse=None,
swa_left = self.sliding_window if self.sliding_window is not None else -1, swa_left=(
swa_right = 0 if self.sliding_window is not None else -1, self.sliding_window if self.sliding_window is not None else -1
sink = self.sinks.to(torch.float32) if self.sinks is not None else None ),
swa_right=0 if self.sliding_window is not None else -1,
sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, ( assert (
"Encoder-only models should not have decode metadata.") attn_type != AttentionType.ENCODER_ONLY
), "Encoder-only models should not have decode metadata."
decode_query = query[:num_decode_tokens] decode_query = query[:num_decode_tokens]
# For hybrid Attention (Qwen3-Next # For hybrid Attention (Qwen3-Next
if key_cache.is_contiguous(): if key_cache.is_contiguous():
tmp_block_tables = decode_meta.block_tables tmp_block_tables = decode_meta.block_tables
else: else:
tmp_block_tables = decode_meta.block_tables * 2 # only test in Qwen3-Next tmp_block_tables = (
decode_meta.block_tables * 2
) # only test in Qwen3-Next
sig = inspect.signature(xtorch_ops.speculative_attention) sig = inspect.signature(xtorch_ops.speculative_attention)
if "max_window_size" in sig.parameters: if "max_window_size" in sig.parameters:
@@ -745,11 +881,15 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
kv_head_num=self.num_kv_heads, kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2], block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1], max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
max_window_size=self.sliding_window if self.sliding_window is not None else -1, max_window_size=(
self.sliding_window if self.sliding_window is not None else -1
),
block_tables=tmp_block_tables, block_tables=tmp_block_tables,
sink = self.sinks.to(torch.float32) if self.sinks is not None else None sink=(
self.sinks.to(torch.float32) if self.sinks is not None else None
),
) )
else: elif not attn_metadata.is_speculative:
xtorch_ops.paged_attention( xtorch_ops.paged_attention(
x=decode_query, x=decode_query,
k_cache=key_cache, k_cache=key_cache,
@@ -760,10 +900,38 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
is_context=False, is_context=False,
is_causal=True, is_causal=True,
out=output[:num_decode_tokens], out=output[:num_decode_tokens],
vo_head_dim=self.head_size vo_head_dim=self.head_size,
)
else:
batch_size = attn_metadata.num_decodes
query_seq_len, head_num, head_dim = decode_query.shape
assert query_seq_len % batch_size == 0
qlen = query_seq_len // batch_size
out = output[:num_decode_tokens]
assert out.is_contiguous()
xtorch_ops.speculative_attention(
out=out.view(batch_size, qlen, head_num, self.head_size),
q=decode_query.view(batch_size, qlen, head_num, head_dim),
k_cache=key_cache,
v_cache=value_cache,
context_lens_cpu=decode_meta.seq_lens_tensor_cpu,
context_lens_xpu=decode_meta.seq_lens_tensor,
batch_num=batch_size,
qlen=qlen,
max_context_len=decode_meta.max_model_len,
head_num=self.num_heads,
head_dim=self.head_size,
scale=0.0,
kv_head_num=self.num_kv_heads,
block_size=key_cache.shape[2],
max_num_blocks_per_seq=decode_meta.block_tables.shape[1],
block_tables=tmp_block_tables,
) )
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
def use_cascade_attention( def use_cascade_attention(
common_prefix_len: int, common_prefix_len: int,
query_lens: np.ndarray, query_lens: np.ndarray,
@@ -803,8 +971,12 @@ def use_cascade_attention(
num_queries_per_kv = num_query_heads // num_kv_heads num_queries_per_kv = num_query_heads // num_kv_heads
# The criteria for using FlashDecoding can be found in the following link: # The criteria for using FlashDecoding can be found in the following link:
# https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535 # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window use_flash_decoding = (
and not use_alibi and np.all(query_lens == 1)) num_queries_per_kv > 1
and not use_sliding_window
and not use_alibi
and np.all(query_lens == 1)
)
if not use_flash_decoding: if not use_flash_decoding:
# Use cascade attention. # Use cascade attention.
return True return True
@@ -826,8 +998,9 @@ def use_cascade_attention(
cascade_waves = cdiv(cascade_ctas, num_sms) cascade_waves = cdiv(cascade_ctas, num_sms)
cascade_time = cascade_waves * num_prefix_tiles cascade_time = cascade_waves * num_prefix_tiles
flash_decoding_ctas = (num_reqs * num_kv_heads * flash_decoding_ctas = (
cdiv(num_queries_per_kv, q_tile_size)) num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
)
flash_decoding_ctas *= num_prefix_tiles flash_decoding_ctas *= num_prefix_tiles
flash_decoding_time = cdiv(flash_decoding_ctas, num_sms) flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

View File

@@ -1,16 +1,24 @@
"""vllm_utils_wrapper.py""" """vllm_utils_wrapper.py"""
import vllm.distributed.parallel_state as parallel_state
import vllm.utils as _orig
from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple
from types import SimpleNamespace
import torch
from torch.library import Library
import inspect import inspect
import socket
import typing import typing
from torch.library import register_fake from types import SimpleNamespace
import vllm_kunlun._kunlun from typing import Any, Callable, List, Optional, Tuple, Union, get_args, get_origin
import torch
import vllm.distributed.parallel_state as parallel_state
import vllm.envs as envs import vllm.envs as envs
import vllm.utils as _orig
from torch.library import Library, register_fake
try:
import vllm_kunlun._kunlun # noqa: F401
except ImportError as e:
try:
from . import _kunlun # noqa: F401, F403
except ImportError:
print(f"Warning: Failed to load vllm_kunlun native extension: {e}")
def patch_annotations_for_schema(func): def patch_annotations_for_schema(func):
@@ -87,7 +95,7 @@ def direct_register_custom_op(
import torch.library import torch.library
if hasattr(torch.library, "infer_schema"): if hasattr(torch.library, "infer_schema"):
patched_func = patch_annotations_for_schema(op_func) patch_annotations_for_schema(op_func)
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
else: else:
# for pytorch 2.4 # for pytorch 2.4
@@ -153,7 +161,7 @@ _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors _wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
_wrapped._get_open_port = _get_open_port _wrapped._get_open_port = _get_open_port
import sys import sys # noqa: E402
sys.modules["vllm.utils"] = _wrapped sys.modules["vllm.utils"] = _wrapped
@@ -204,11 +212,10 @@ parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
from torch.library import custom_op, impl from typing import Optional # noqa: E402
import torch
from vllm import _custom_ops as ops import torch # noqa: E402
from typing import Optional, List from torch.library import custom_op, impl # noqa: E402
import os
@custom_op("_C::rms_norm", mutates_args=()) @custom_op("_C::rms_norm", mutates_args=())
@@ -379,9 +386,9 @@ def silu_and_mul_quant_xpu(
pass pass
import torch import torch # noqa: E402
import xtorch_ops import xtorch_ops # noqa: E402
from torch.library import custom_op, impl from torch.library import custom_op, impl # noqa: E402
@custom_op("_C::add_rmsnorm", mutates_args=()) @custom_op("_C::add_rmsnorm", mutates_args=())
@@ -472,7 +479,7 @@ def rmsnorm_cuda(
) )
import torch import torch # noqa: E402
def _fake_rmsnorm( def _fake_rmsnorm(
@@ -618,7 +625,6 @@ split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
# register fake op impl here # register fake op impl here
# for torch.dynamo # for torch.dynamo
from torch.library import register_fake
if hasattr(torch.ops.custom_ops, "fc_fusion"): if hasattr(torch.ops.custom_ops, "fc_fusion"):
@@ -1396,7 +1402,7 @@ def awq_dequantize_cuda(
device=qweight.device, device=qweight.device,
) )
group_m = int(qweight.shape[0] / scales.shape[0]) group_m = int(qweight.shape[0] / scales.shape[0])
out = xtorch_ops.awq_dequantize( xtorch_ops.awq_dequantize(
qweight=qweight, qweight=qweight,
scales=scales, scales=scales,
zeros=zeros, zeros=zeros,
@@ -1915,7 +1921,7 @@ def apply_repetition_penalties_(
@impl("_C::apply_repetition_penalties_", "CUDA") @impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_( def apply_repetition_penalties_cuda(
logits: torch.Tensor, logits: torch.Tensor,
prompt_mask: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor, output_mask: torch.Tensor,
@@ -2341,34 +2347,499 @@ dequant_int4.register_fake(_fake_dequant_int4)
################################################## ##################################################
@custom_op("_C::fast_topkv2", mutates_args=()) @custom_op("_C::fast_topkv2", mutates_args=())
def fast_topkv2( def fast_topkv2(
score: torch.Tensor, score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
lengths: torch.Tensor, ) -> torch.Tensor:
topk: Optional[int] = 2048) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
topk_indices = xtorch_ops.fast_topkv2( topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
score=score,
lengths=lengths,
topk=topk)
return topk_indices return topk_indices
@impl("_C::fast_topkv2", "CUDA") @impl("_C::fast_topkv2", "CUDA")
def fast_topkv2_cuda( def fast_topkv2_cuda(
score: torch.Tensor, score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
lengths: torch.Tensor, ) -> torch.Tensor:
topk: Optional[int] = 2048) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
topk_indices = xtorch_ops.fast_topkv2( topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
score=score,
lengths=lengths,
topk=topk)
return topk_indices return topk_indices
def _fake_fast_topkv2( def _fake_fast_topkv2(
score: torch.Tensor, score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
lengths: torch.Tensor, ) -> torch.Tensor:
topk: Optional[int] = 2048) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now" assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32) topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
return topk_indices return topk_indices
fast_topkv2.register_fake(_fake_fast_topkv2) fast_topkv2.register_fake(_fake_fast_topkv2)
##################################################
# ----------------- LoRA ops --------------------
##################################################
##################################################
# -------------- sgmv_shrink_lora ----------------
##################################################
@custom_op("_C::sgmv_shrink_lora", mutates_args=())
def sgmv_shrink_lora(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> torch.Tensor:
# return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(
# inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling
# )
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
inputs,
lora_a_weights,
seq_len_tensor,
lora_indices_tensor,
output_tensor,
scaling,
)
@impl("_C::sgmv_shrink_lora", "CUDA")
def sgmv_shrink_lora_cuda(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> torch.Tensor:
# return torch.ops.xspeedgate_ops.sgmv_shrink_cluster(
# inputs, lora_a_weights, seq_len_tensor, lora_indices_tensor, output_tensor, scaling
# )
return torch.ops.xspeedgate_ops.sgmv_shrink_sdnn(
inputs,
lora_a_weights,
seq_len_tensor,
lora_indices_tensor,
output_tensor,
scaling,
)
def _fake_sgmv_shrink_lora(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> torch.Tensor:
return output_tensor
sgmv_shrink_lora.register_fake(_fake_sgmv_shrink_lora)
##################################################
# -------------- sgmv_expand_lora ----------------
##################################################
@custom_op("_C::sgmv_expand_lora", mutates_args=())
def sgmv_expand_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> torch.Tensor:
# return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
# inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
# )
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
)
@impl("_C::sgmv_expand_lora", "CUDA")
def sgmv_expand_lora_cuda(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> torch.Tensor:
# return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
# inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
# )
return torch.ops.xspeedgate_ops.sgmv_expand_sdnn(
inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
)
def _fake_sgmv_expand_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> torch.Tensor:
return output_tensor
sgmv_expand_lora.register_fake(_fake_sgmv_expand_lora)
##################################################
# ----------- sgmv_expand_slice_lora -------------
##################################################
@custom_op("_C::sgmv_expand_slice_lora", mutates_args=())
def sgmv_expand_slice_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
inputs,
lora_b_weights,
seq_len_tensor,
lora_indices_tensor,
output_tensor,
slice_offset,
)
@impl("_C::sgmv_expand_slice_lora", "CUDA")
def sgmv_expand_slice_lora_cuda(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.sgmv_expand_cluster(
inputs,
lora_b_weights,
seq_len_tensor,
lora_indices_tensor,
output_tensor,
slice_offset,
)
def _fake_sgmv_expand_slice_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
) -> torch.Tensor:
return output_tensor
sgmv_expand_slice_lora.register_fake(_fake_sgmv_expand_slice_lora)
##################################################
# -------------- bgmv_shrink_lora ----------------
##################################################
@custom_op("_C::bgmv_shrink_lora", mutates_args=())
def bgmv_shrink_lora(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
)
@impl("_C::bgmv_shrink_lora", "CUDA")
def bgmv_shrink_lora_cuda(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_shrink_cluster(
inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
)
def _fake_bgmv_shrink_lora(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> torch.Tensor:
return output_tensor
bgmv_shrink_lora.register_fake(_fake_bgmv_shrink_lora)
##################################################
# -------------- bgmv_expand_lora ----------------
##################################################
@custom_op("_C::bgmv_expand_lora", mutates_args=())
def bgmv_expand_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
)
@impl("_C::bgmv_expand_lora", "CUDA")
def bgmv_expand_lora_cuda(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
)
def _fake_bgmv_expand_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> torch.Tensor:
return output_tensor
bgmv_expand_lora.register_fake(_fake_bgmv_expand_lora)
##################################################
# ----------- bgmv_expand_slice_lora -------------
##################################################
@custom_op("_C::bgmv_expand_slice_lora", mutates_args=())
def bgmv_expand_slice_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
)
@impl("_C::bgmv_expand_slice_lora", "CUDA")
def bgmv_expand_slice_lora_cuda(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> torch.Tensor:
return torch.ops.xspeedgate_ops.bgmv_expand_cluster(
inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
)
def _fake_bgmv_expand_slice_lora(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
block_statistic: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> torch.Tensor:
return output_tensor
bgmv_expand_slice_lora.register_fake(_fake_bgmv_expand_slice_lora)
##################################################
# ----------- lora_matmul_inplace ----------------
##################################################
@custom_op("_C::lora_matmul_inplace", mutates_args=())
def lora_matmul_inplace(
x: torch.Tensor,
w: torch.Tensor,
output_tensor: torch.Tensor,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 1.0,
) -> None:
xtorch_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=output_tensor,
x_trans=x_trans,
w_trans=w_trans,
alpha=alpha,
beta=beta,
)
@impl("_C::lora_matmul_inplace", "CUDA")
def lora_matmul_inplace_cuda(
x: torch.Tensor,
w: torch.Tensor,
output_tensor: torch.Tensor,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 1.0,
) -> None:
xtorch_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=output_tensor,
x_trans=x_trans,
w_trans=w_trans,
alpha=alpha,
beta=beta,
)
def _fake_lora_matmul_inplace(
x: torch.Tensor,
w: torch.Tensor,
output_tensor: torch.Tensor,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 1.0,
) -> None:
return None
lora_matmul_inplace.register_fake(_fake_lora_matmul_inplace)