Compare commits

16 Commits

Author SHA1 Message Date
starkwj
34e04c5569 update base image 2026-03-02 18:46:04 +08:00
starkwj
a15754c3ba add readme 2026-03-02 18:40:49 +08:00
starkwj
4d8575115a add vxpu 2026-03-02 18:38:10 +08:00
lishaobing448
dc63e81a7f fix: use cuda visible (#244)
Signed-off-by: lishaobing448 <shaobingli2024@163.com>
2026-03-02 17:33:13 +08:00
Li Wei
e4c9b9f988 [Bugfix] cocopod ops can't be finded (#242)
Signed-off-by: Li Wei <liwei.109@outlook.com>
2026-03-02 15:49:24 +08:00
Joeegin
171f664a0f [Doc] Update dependencies (#225)
Signed-off-by: Joeegin <3318329726@qq.com>
2026-03-02 10:50:12 +08:00
chanzhennan
82544aa0cc [Feature] Merge branch 'Qwen3-Next' into main && Support Qwen-next (#222)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
Co-authored-by: xyDong0223 <dongxinyu03@baidu.com>
2026-02-28 11:15:50 +08:00
Lidang Jiang
153093d3b3 [Misc] add collect_env feat (#218)
Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
2026-02-27 12:19:58 +08:00
Xinyu Dong
d425a0d0e9 [Docs] Add vLLM-Kunlun New Model Adaptation Manual and Update Model Support (#211)
* [Docs] Fix app.readthedocs buliding

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>

* [Docs] Add vLLM-Kunlun New Model Adaptation Manual and Update Model Support

Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
2026-02-26 10:06:58 +08:00
Shiwen Tang
b82b6026d6 [BugFix] Adapt GLM5 config for transformers 4.57 (#207)
Signed-off-by: tangshiwen <tangshiwen@baidu.com>
2026-02-25 18:47:26 +08:00
Xinyu Dong
a470452871 [Docs] Fix app.readthedocs buliding (#210)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
2026-02-17 16:17:25 +08:00
Xinyu Dong
d9ad42a174 [Docs] Fix quantization support description in README (#208)
Updated quantization support description from FP8 to INT8.
2026-02-15 13:12:17 +08:00
Xinyu Dong
77dbc2ddeb [Docs] Update README (#206)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
2026-02-15 11:05:54 +08:00
Xinyu Dong
76ec220b43 [Bugsfix] Fix run failed (#198)
Signed-off-by: xyDong0223 <dongxinyu03@baidu.com>
2026-02-13 14:07:10 +08:00
Xinyu Dong
bf9369f733 Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
2026-02-12 18:13:00 +08:00
Li Wei
744719587e [Feature] Support glmx (#194)
Signed-off-by: Li Wei <liwei.109@outlook.com>
Co-authored-by: tangshiwen <tangshiwen@baidu.com>
Co-authored-by: Xinyu Dong <dongxinyu03@baidu.com>
2026-02-12 15:40:42 +08:00
47 changed files with 5326 additions and 2204 deletions

View File

@@ -1,4 +1,4 @@
FROM wjie520/vllm_kunlun:base_v0.0.2
FROM vllm_kunlun:custom_base_v0.0.3
WORKDIR /workspace

View File

@@ -11,11 +11,12 @@ One of the key features of this project is efficient memory coordination, enabli
### Build from Dockerfile
Clone this repository:
1. Get or build base image (base with customized xpytorch, ops, etc.). Ref: [installation](https://vllm-kunlun.readthedocs.io/en/latest/installation.html).
```bash
docker build -t $build_image -f ./Dockerfile .
```
2. Clone this repository and build
```bash
docker build -t $build_image -f ./Dockerfile .
```
## Usage
@@ -25,4 +26,4 @@ docker build -t $build_image -f ./Dockerfile .
### Environment Variables
- `VXPU_RESERVED_VRAM_SIZE_GB`: The amount of GPU memory (in GB) reserved for miscellaneous allocations. Only needs to be set for `vllm_vxpu_daemon`. Try increasing this value if you launch multiple LLM services and encounter OOM. Default: `8`.
- `VLLM_VXPU_SHM_NAME`: The name of the shm file. Needs to be set for all containers of the shared vxpu group. Default: `/vllm_kunlun_vxpu_offload_shm`.
- `VLLM_VXPU_SHM_NAME`: The name of the shm file. Needs to be set for all containers of the shared vxpu group. Default: `/vllm_kunlun_vxpu_offload_shm`.
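These defaults can be mirrored when reading the variables in code; a minimal illustrative sketch (only the variable names and default values come from the list above, the reading code itself is hypothetical):

```python
import os

# Resolve the documented vxpu settings with their README defaults.
reserved_vram_gb = int(os.environ.get("VXPU_RESERVED_VRAM_SIZE_GB", "8"))
shm_name = os.environ.get("VLLM_VXPU_SHM_NAME", "/vllm_kunlun_vxpu_offload_shm")

print(f"reserved={reserved_vram_gb}GB shm={shm_name}")
```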

View File

@@ -1,212 +1,199 @@
![vLLM Kunlun Logo](vllm_kunlun/patches/vLLM_Kunlun.jpg)
<p align="center">
<a href="https://vllm-kunlun.readthedocs.io/en/latest/"><b> Documentation</b></a> |
<a href="https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html"><b> Quick Start</b></a> |
<a href="https://join.slack.com/t/vllm-kunlun/shared_invite/zt-3iinb8u5z-FcqZKbNNdMJ_32fHmipzvw"><b> Slack</b></a>
<a href="https://vllm-kunlun.readthedocs.io/en/latest/"><b>📖 Documentation</b></a> |
<a href="https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html"><b>🚀 Quick Start</b></a> |
<a href="https://vllm-kunlun.readthedocs.io/en/latest/installation.html"><b>📦 Installation</b></a> |
<a href="https://join.slack.com/t/vllm-kunlun/shared_invite/zt-3iinb8u5z-FcqZKbNNdMJ_32fHmipzvw"><b>💬 Slack</b></a>
</p>
<p align="center">
<img alt="GitHub License" src="https://img.shields.io/github/license/baidu/vLLM-Kunlun">
<img alt="GitHub Stars" src="https://img.shields.io/github/stars/baidu/vLLM-Kunlun">
<img alt="GitHub Forks" src="https://img.shields.io/github/forks/baidu/vLLM-Kunlun">
<img alt="GitHub Issues" src="https://img.shields.io/github/issues/baidu/vLLM-Kunlun">
<img alt="Python Version" src="https://img.shields.io/badge/python-%3E%3D3.10-blue">
</p>
---
## Latest News 🔥
- [2025/12] Initial release of vLLM Kunlun
- [2026/02] 🧠 **GLM model family support** — Added GLM5, GLM-4.7 MTP (Multi-Token Prediction), and GLM-47 tool parser with thinking/non-thinking mode toggle
- [2026/02] ⚡ **Performance optimizations** — Fused MoE with small batches, optimized attention metadata building, Multi-LoRA inference achieves 80%+ of non-LoRA performance
- [2026/02] 🔧 **DeepSeek-V3.2 MTP support** — Added MTP (Multi-Token Prediction) for DeepSeek-V3.2, with RoPE and decoding stage kernel optimizations
- [2026/01] 🔢 **New quantization methods** — Support for compressed-tensors W4A16, AWQ MoE W4A16, and DeepSeek-V3.2 W8A8 quantization
- [2026/01] 🛠️ **CI/CD overhaul** — Added E2E tests, unit test CI, ruff format checks, and modular CI workflow refactoring
- [2025/12] 🎉 **v0.11.0rc1 released** — Added Qwen3-Omni, Qwen3-Next, Seed-OSS support ([Release Notes](https://github.com/baidu/vLLM-Kunlun/releases/tag/v0.11.0rc1))
- [2025/12] 📦 **v0.10.1.1 released** — 5+ multimodal models, AWQ/GPTQ quantization for dense models, Piecewise CUDA Graph, vLLM V1 engine, Flash-Infer Top-K/Top-P sampling with 10-100× speedup ([Release Notes](https://github.com/baidu/vLLM-Kunlun/releases/tag/v0.10.1.1))
- [2025/12] 🌟 Initial release of vLLM Kunlun — Open sourced on Dec 8, 2025
---
# Overview
## Overview
vLLM Kunlun (vllm-kunlun) is a community-maintained hardware plugin designed to seamlessly run vLLM on the Kunlun XPU. It is the recommended approach for integrating the Kunlun backend within the vLLM community, adhering to the principles outlined in the [RFC Hardware pluggable](https://github.com/vllm-project/vllm/issues/11162). This plugin provides a hardware-pluggable interface that decouples the integration of the Kunlun XPU with vLLM.
**vLLM Kunlun** (`vllm-kunlun`) is a community-maintained hardware plugin designed to seamlessly run [vLLM](https://github.com/vllm-project/vllm) on the **Kunlun XPU**. It is the recommended approach for integrating the Kunlun backend within the vLLM community, adhering to the principles outlined in the [RFC Hardware Pluggable](https://github.com/vllm-project/vllm/issues/11162).
By utilizing the vLLM Kunlun plugin, popular open-source models, including Transformer-like, Mixture-of-Expert, Embedding, and Multi-modal LLMs, can run effortlessly on the Kunlun XPU.
This plugin provides a hardware-pluggable interface that decouples the integration of the Kunlun XPU with vLLM. By utilizing vLLM Kunlun, popular open-source models including Transformer-like, Mixture-of-Expert (MoE), Embedding, and Multi-modal LLMs can run effortlessly on the Kunlun XPU.
### ✨ Key Features
- **Seamless Plugin Integration** — Works as a standard vLLM platform plugin via Python entry points, no need to modify vLLM source code
- **Broad Model Support** — Supports 15+ mainstream LLMs including Qwen, Llama, DeepSeek, Kimi-K2, and multimodal models
- **Quantization Support** — INT8 and other quantization methods for MoE and dense models
- **LoRA Fine-Tuning** — LoRA adapter support for Qwen series models
- **Piecewise Kunlun Graph** — Hardware-accelerated graph optimization for high-performance inference
- **FlashMLA Attention** — Optimized multi-head latent attention for DeepSeek MLA architectures
- **Tensor Parallelism** — Multi-device parallel inference with distributed execution support
- **OpenAI-Compatible API** — Serve models with the standard OpenAI API interface
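The "standard vLLM platform plugin via Python entry points" mentioned above works by registering under vLLM's plugin entry-point group. A sketch of the registration data a plugin would pass to setuptools' `setup(entry_points=...)` — the group name `vllm.platform_plugins` follows vLLM's hardware-plugin mechanism, while the module path `vllm_kunlun:register` is an illustrative example, not necessarily the project's actual entry point:

```python
# Illustrative only: how a platform plugin is exposed via Python entry points.
# The registered callable is expected to return the platform class path.
entry_points = {
    "vllm.platform_plugins": [
        "kunlun = vllm_kunlun:register",  # hypothetical module path
    ],
}
print(entry_points["vllm.platform_plugins"][0])
```

With such an entry point installed, vLLM discovers the backend at startup without any change to vLLM's own source code.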
---
## Prerequisites
- **Hardware**: Kunlun3 P800
- **OS**: Ubuntu 22.04
- **Hardware**: Kunlun3 P800
- **OS**: Ubuntu 22.04
- **Software**:
- Python >=3.10
- PyTorch 2.5.1
- Python >= 3.10
- PyTorch >= 2.5.1
- vLLM (same version as vllm-kunlun)
- transformers >= 4.57.0
---
## Supported Models
<h3>Generative Models</h3>
<table>
<thead>
<tr>
<th width="30%">Model</th>
<th width="12%">Support</th>
<th width="15%">Quantization</th>
<th width="10%">LoRA</th>
<th width="20%">Piecewise Kunlun Graph</th>
<th width="23%">Note</th>
</tr>
</thead>
<tbody>
<tr>
<td class="model-name">Qwen2</td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen2.5</td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3</td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3-Moe</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Qwen3-Next</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">MiMo-V2-Flash</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Llama2</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Llama3</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Llama3.1</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">gpt-oss</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td></td>
<td></td>
</tr>
<tr>
<td class="model-name">DeepSeek-R1</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">DeepSeek-V3</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">DeepSeek-V3.2</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
<tr>
<td class="model-name">Kimi-K2</td>
<td class="status-support"></td>
<td class="status-support"></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
</tbody>
</table>
### Generative Models
<h3>Multimodal Language Models</h3>
<table>
<thead>
<tr>
<th width="20%">Model</th>
<th width="12%">Support</th>
<th width="15%">Quantization</th>
<th width="10%">LoRA</th>
<th width="20%">Piecewise Kunlun Graph</th>
<th width="23%">Note</th>
</tr>
</thead>
<tbody>
<tr>
<td class="model-name">Qwen3-VL</td>
<td class="status-support"></td>
<td></td>
<td></td>
<td class="status-support"></td>
<td></td>
</tr>
</tbody>
</table>
| Model | Support | Quantization | LoRA | Kunlun Graph |
|:------|:-------:|:------------:|:----:|:----------------------:|
| Qwen2 | ✅ | ✅| ✅ | ✅ |
| Qwen2.5 | ✅ |✅ | ✅ | ✅ |
| Qwen3 | ✅ |✅ | ✅ | ✅ |
| Qwen3-Moe | ✅ | ✅ | | ✅ |
| Qwen3-Next | ✅ | ✅ | | ✅ |
| MiMo-V2-Flash | ✅ | ✅| | ✅ |
| Llama2 | ✅ | ✅| ✅| ✅ |
| Llama3 | ✅ |✅ | ✅ | ✅ |
| Llama3.1 | ✅ |✅ | | ✅ |
| gpt-oss | ✅ | ✅| | |
| GLM4.5 | ✅ | ✅| | ✅ |
| GLM4.5Air | ✅ |✅ | | ✅ |
| GLM4.7 | ✅ | ✅| | ✅ |
| GLM5 | ✅ | ✅| | ✅ |
| Kimi-K2 | ✅ | ✅ | | ✅ |
| DeepSeek-R1 | ✅ | ✅ | | ✅ |
| DeepSeek-V3 | ✅ | ✅ | | ✅ |
| DeepSeek-V3.2 | ✅ | ✅ | | ✅ |
### Multimodal Language Models
| Model | Support | Quantization | LoRA | Kunlun Graph |
|:------|:-------:|:------------:|:----:|:----------------------:|
| Qwen2-VL | ✅ | ✅| | ✅ |
| Qwen2.5-VL | ✅ | ✅| | ✅ |
| Qwen3-VL | ✅ | ✅| | ✅ |
| Qwen3-VL-MoE | ✅ | ✅ | | ✅ |
| Qwen3-Omni-MoE | ✅ | | | ✅ |
| InternVL-2.5 | ✅ | | | ✅ |
| InternVL-3.5 | ✅ | | | ✅ |
| InternS1 | ✅ | | | ✅ |
---
## Performance Visualization 🚀
### High-performance computing at work: How different models perform on the Kunlun3 P800.
Benchmark environment: 16-way concurrency, input/output length 2048.
![Models and tgs](./vllm_kunlun/patches/performance.png)
## Getting Started
Please use the following recommended versions to get started quickly:
| Version | Release type | Doc |
|----------|---------------|-----|
| v0.11.0 | Latest stable version | [QuickStart](https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html) and [Installation](https://vllm-kunlun.readthedocs.io/en/latest/installation.html) for more details |
---
## Contribute to vLLM Kunlun
### Quick Start
If you're interested in contributing to this project, please read the [Contributing](CONTRIBUTING.md) guide.
#### Start an OpenAI-Compatible API Server
```bash
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8356 \
--model <your-model-path> \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 1 \
--dtype float16 \
--max-num-seqs 128 \
--max-num-batched-tokens 32768 \
--block-size 128 \
--distributed-executor-backend mp \
--served-model-name <your-model-name>
```
#### Send a Request
```bash
curl http://localhost:8356/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "<your-model-name>",
"messages": [{"role": "user", "content": "Hello!"}],
"max_tokens": 512
}'
```
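The same request can be composed from Python using only the standard library; a sketch that builds the payload and request object (the endpoint and model name are the same placeholders as in the curl example, and the request is not actually sent here):

```python
import json
import urllib.request

payload = {
    "model": "<your-model-name>",  # placeholder, as in the curl example above
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 512,
}

req = urllib.request.Request(
    "http://localhost:8356/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
# urllib.request.urlopen(req) would send it once the server is running.
print(req.get_method(), req.full_url)
```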
### Version Matrix
| Version | Release Type | Documentation |
|---------|:------------:|:-------------:|
| v0.11.0 | Latest stable version | [Quick Start](https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html) · [Installation](https://vllm-kunlun.readthedocs.io/en/latest/installation.html) |
---
## Architecture
```
vllm-kunlun/
├── vllm_kunlun/ # Core plugin package
│ ├── platforms/ # Kunlun XPU platform implementation
│ ├── models/ # Model implementations (DeepSeek, Qwen, Llama, etc.)
│ ├── ops/ # Custom operators (attention, linear, sampling, etc.)
│ │ ├── attention/ # FlashMLA, paged attention, merge attention states
│ │ ├── fla/ # Flash linear attention operations
│ │ └── sample/ # Sampling operators
│ ├── v1/ # vLLM V1 engine adaptations
│ ├── compilation/ # Torch compile wrapper for Kunlun Graph
│ ├── csrc/ # C++ extensions (custom CUDA-compatible kernels)
│ └── config/ # Model configuration overrides
├── tests/ # Test suite
├── docs/ # Documentation (Sphinx-based, ReadTheDocs hosted)
├── ci/ # CI pipeline configurations
├── setup.py # Legacy build script (with C++ extensions)
└── pyproject.toml # Modern Python build configuration (hatchling)
```
---
## Contributing
We welcome contributions from the community! Please read our [Contributing Guide](CONTRIBUTING.md) before submitting a PR.
### PR Classification
Use the following prefixes for PR titles:
- `[Attention]` — Attention mechanism features/optimizations
- `[Core]` — Core vllm-kunlun logic (platform, attention, communicators, model runner)
- `[Kernel]` — Compute kernels and ops
- `[Bugfix]` — Bug fixes
- `[Doc]` — Documentation improvements
- `[Test]` — Tests
- `[CI]` — CI/CD improvements
- `[Misc]` — Other changes
---
## Star History 🔥
@@ -214,10 +201,14 @@ We opened the project on Dec 8, 2025. We love open source and collaboration ❤
[![Star History Chart](https://api.star-history.com/svg?repos=baidu/vLLM-Kunlun&type=date&legend=bottom-right)](https://www.star-history.com/#baidu/vLLM-Kunlun&type=date&legend=bottom-right)
---
## Sponsors 👋
We sincerely appreciate the [**KunLunXin**](https://www.kunlunxin.com/) team for their support in providing XPU resources, which enabled efficient model adaptation debugging, comprehensive end-to-end testing, and broader model compatibility.
---
## License
Apache License 2.0, as found in the [LICENSE](./LICENSE) file.
Apache License 2.0, as found in the [LICENSE](./LICENSE) file.

collect_env.py (new file, 695 lines)
View File

@@ -0,0 +1,695 @@
# SPDX-License-Identifier: Apache-2.0
# vLLM-Kunlun Environment Information Collection Tool (Fixed Version)
"""
Environment information collection script for Kunlun XPU
Fixed the following issues:
1. Device name displayed as "GPU" → Now correctly shows "P800 OAM"
2. XRE version command error → Now parsed from xpu-smi output
3. vLLM-Kunlun version hardcoded → Now fetched from pip package metadata
"""
import os
import re
import subprocess
import sys
from collections import namedtuple
# =============================================================================
# Part 1: Basic Utility Functions
# =============================================================================
def run(command):
"""
Execute a shell command and return its result.
Like fetch() in frontend development: send a request, get a response back.
Args:
command: Command as a string (run through the shell) or a list of arguments.
Returns:
tuple: (return_code, stdout, stderr)
"""
shell = isinstance(command, str)
try:
p = subprocess.Popen(
command,
stdout=subprocess.PIPE, # Capture standard output
stderr=subprocess.PIPE, # Capture error output
shell=shell,
)
raw_output, raw_err = p.communicate()
rc = p.returncode
# Decode byte stream to string
output = raw_output.decode("utf-8").strip()
err = raw_err.decode("utf-8").strip()
return rc, output, err
except FileNotFoundError:
return 127, "", "Command not found"
def run_and_read_all(run_lambda, command):
"""Execute command, return output if successful, None otherwise"""
rc, out, _ = run_lambda(command)
if rc != 0:
return None
return out
def run_and_parse_first_match(run_lambda, command, regex):
"""Execute command and extract first regex match"""
rc, out, _ = run_lambda(command)
if rc != 0:
return None
match = re.search(regex, out)
if match is None:
return None
return match.group(1)
# Check if PyTorch is available
try:
import torch
TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
TORCH_AVAILABLE = False
# =============================================================================
# Part 2: General System Information Collection (Reusing vLLM Original Logic)
# =============================================================================
def get_platform():
"""Get operating system platform"""
if sys.platform.startswith("linux"):
return "linux"
elif sys.platform.startswith("win32"):
return "win32"
elif sys.platform.startswith("darwin"):
return "darwin"
return sys.platform
def get_os(run_lambda):
"""Get detailed operating system information"""
from platform import machine
if get_platform() == "linux":
# Try reading /etc/*-release
rc, out, _ = run_lambda(
"cat /etc/*-release 2>/dev/null | grep PRETTY_NAME | head -1"
)
if rc == 0 and out:
match = re.search(r'PRETTY_NAME="(.*)"', out)
if match:
return f"{match.group(1)} ({machine()})"
# Fallback: use lsb_release
rc, out, _ = run_lambda("lsb_release -d 2>/dev/null")
if rc == 0 and out:
match = re.search(r"Description:\s*(.*)", out)
if match:
return f"{match.group(1)} ({machine()})"
return f"{get_platform()} ({machine()})"
def get_gcc_version(run_lambda):
"""Get GCC version"""
return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
def get_clang_version(run_lambda):
"""Get Clang version"""
return run_and_parse_first_match(
run_lambda, "clang --version", r"clang version (.*)"
)
def get_cmake_version(run_lambda):
"""Get CMake version"""
return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
def get_libc_version():
"""Get libc version"""
import platform
if get_platform() != "linux":
return "N/A"
return "-".join(platform.libc_ver())
def get_python_platform():
"""Get Python platform information"""
import platform
return platform.platform()
def get_cpu_info(run_lambda):
"""Get CPU information"""
if get_platform() == "linux":
rc, out, err = run_lambda("lscpu")
return out if rc == 0 else err
return "N/A"
def get_pip_packages(run_lambda, patterns=None):
"""Get pip package list"""
if patterns is None:
patterns = {
"torch",
"numpy",
"triton",
"transformers",
"vllm",
"kunlun",
"xpu",
"bkcl",
"xmlir",
}
cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
out = run_and_read_all(run_lambda, cmd)
if out is None:
return "pip3", ""
filtered = "\n".join(
line
for line in out.splitlines()
if any(name.lower() in line.lower() for name in patterns)
)
pip_version = "pip3" if sys.version[0] == "3" else "pip"
return pip_version, filtered
def get_conda_packages(run_lambda, patterns=None):
"""Get conda package list"""
if patterns is None:
patterns = {
"torch",
"numpy",
"triton",
"transformers",
"kunlun",
"xpu",
"bkcl",
"xmlir",
}
conda = os.environ.get("CONDA_EXE", "conda")
out = run_and_read_all(run_lambda, [conda, "list"])
if out is None:
return None
return "\n".join(
line
for line in out.splitlines()
if not line.startswith("#")
and any(name.lower() in line.lower() for name in patterns)
)
# =============================================================================
# Part 3: Kunlun-Specific Information Collection (Core Fix)
# =============================================================================
def parse_xpu_smi_output(run_lambda):
"""
Parse the complete output of xpu-smi command
[Principle Explanation]
The xpu-smi output format is similar to nvidia-smi, we need to parse it with regex.
Example output format:
+-----------------------------------------------------------------------------+
| XPU-SMI Driver Version: 515.58 XPU-RT Version: N/A |
|-------------------------------+----------------------+----------------------+
| 0 P800 OAM N/A | 00000000:52:00.0 N/A | 0 |
| N/A 43C N/A 85W / 400W | 4MiB / 98304MiB | 0% Default |
Returns:
dict: Dictionary containing parsing results
"""
rc, output, _ = run_lambda("xpu-smi")
if rc != 0 or not output:
return None
result = {
"raw_output": output,
"driver_version": None,
"xre_version": None,
"devices": [],
}
# Parse header: Driver Version and XPU-RT Version
# Format: | XPU-SMI Driver Version: 515.58 XPU-RT Version: N/A |
header_match = re.search(
r"Driver Version:\s*(\S+)\s+XPU-RT Version:\s*(\S+)", output
)
if header_match:
result["driver_version"] = header_match.group(1)
xre = header_match.group(2)
result["xre_version"] = xre if xre != "N/A" else None
# Parse device information
# Format: | 0 P800 OAM N/A | 00000000:52:00.0 N/A |
# Following: | N/A 43C N/A 85W / 400W | 4MiB / 98304MiB |
# Find all device lines (containing device ID and name)
device_pattern = re.compile(
r"\|\s*(\d+)\s+(\S+(?:\s+\S+)?)\s+(?:N/A|On|Off)\s*\|" # ID and Name
r"\s*([0-9a-fA-F:\.]+)\s*" # Bus-Id
)
# Find memory information
memory_pattern = re.compile(
r"\|\s*N/A\s+\d+C\s+N/A\s+\d+W\s*/\s*\d+W\s*\|"
r"\s*(\d+)MiB\s*/\s*(\d+)MiB\s*\|" # Memory-Usage / Total
)
lines = output.split("\n")
i = 0
while i < len(lines):
line = lines[i]
device_match = device_pattern.search(line)
if device_match:
device_id = int(device_match.group(1))
device_name = device_match.group(2).strip()
bus_id = device_match.group(3)
# Next line should have memory info
memory_used = 0
memory_total = 0
if i + 1 < len(lines):
mem_match = memory_pattern.search(lines[i + 1])
if mem_match:
memory_used = int(mem_match.group(1))
memory_total = int(mem_match.group(2))
result["devices"].append(
{
"id": device_id,
"name": device_name, # This will correctly get "P800 OAM"
"bus_id": bus_id,
"memory_used_mib": memory_used,
"memory_total_mib": memory_total,
}
)
i += 1
return result
def get_kunlun_gpu_info(run_lambda):
"""
Get Kunlun XPU device information
[Fix Explanation]
Previously used torch.cuda.get_device_properties() to get the name,
but it only returns "GPU" (because Kunlun masquerades as CUDA).
Now parse xpu-smi output to correctly get "P800 OAM".
Returns:
str: Device information string
"""
parsed = parse_xpu_smi_output(run_lambda)
if parsed and parsed["devices"]:
# Get real device name from xpu-smi parsing
lines = []
for dev in parsed["devices"]:
memory_gb = dev["memory_total_mib"] / 1024
# Correctly display: XPU 0: P800 OAM (96.0GB)
lines.append(f"XPU {dev['id']}: {dev['name']} ({memory_gb:.1f}GB)")
return "\n".join(lines)
# Fallback: Use PyTorch interface (but will display as GPU)
if TORCH_AVAILABLE:
try:
device_count = torch.cuda.device_count()
lines = []
for i in range(device_count):
props = torch.cuda.get_device_properties(i)
name = props.name if hasattr(props, "name") else "Kunlun XPU"
memory_gb = (
props.total_memory / (1024**3)
if hasattr(props, "total_memory")
else 0
)
lines.append(f"XPU {i}: {name} ({memory_gb:.1f}GB)")
return "\n".join(lines)
except Exception as e:
return f"Error: {e}"
return None
def get_kunlun_driver_version(run_lambda):
"""
Get Kunlun driver version
[Fix Explanation]
Parse directly from xpu-smi output header instead of calling incorrect commands.
Returns:
str: Driver version, e.g., "515.58"
"""
parsed = parse_xpu_smi_output(run_lambda)
if parsed and parsed["driver_version"]:
return parsed["driver_version"]
return None
def get_kunlun_xre_version(run_lambda):
"""
Get Kunlun XRE (Runtime) version
[Fix Explanation]
Previously used `xpu-smi --version` but that parameter doesn't exist.
Now parse the "XPU-RT Version" field from xpu-smi standard output header.
Returns:
str: XRE version, or "N/A (not installed or not detected)" when unavailable
"""
parsed = parse_xpu_smi_output(run_lambda)
if parsed and parsed["xre_version"]:
return parsed["xre_version"]
return "N/A (not installed or not detected)"
def get_kunlun_topo(run_lambda):
"""
Get Kunlun XPU topology information
Returns:
str: Topology information
"""
# xpu-smi topo -m command can get topology
output = run_and_read_all(run_lambda, "xpu-smi topo -m")
if output:
return output
# Fallback: Show device count
if TORCH_AVAILABLE:
try:
count = torch.cuda.device_count()
return f"Detected {count} Kunlun XPU device(s)"
except Exception:
pass
return None
def get_bkcl_version(run_lambda):
"""
Get BKCL (communication library) version
[Principle Explanation]
BKCL = Baidu Kunlun Communication Library
Similar to NVIDIA's NCCL, used for multi-card communication.
Returns:
str: BKCL version information
"""
# Method 1: BKCL prints its version when loading, e.g.:
# [WARN][BKCL][globals.cpp:268] xccl version: 6ab4ffb [rdma] ...
# Try locating the library via related modules instead.
try:
# Try getting from torch_xmlir
import torch_xmlir
# Find path to libbkcl.so
bkcl_path = None
if hasattr(torch_xmlir, "__file__"):
import os
base = os.path.dirname(torch_xmlir.__file__)
candidate = os.path.join(base, "libbkcl.so")
if os.path.exists(candidate):
bkcl_path = candidate
if bkcl_path:
return f"Found at: {bkcl_path}"
except ImportError:
pass
# Method 2: Search from ldconfig
rc, out, _ = run_lambda("ldconfig -p 2>/dev/null | grep -i bkcl | head -1")
if rc == 0 and out:
return out
return None
def get_vllm_kunlun_version():
"""
Get vLLM-Kunlun version
[Fix Explanation]
Previously got hardcoded version "0.9.2" from vllm_kunlun.platforms.version,
but actual pip installed version is "0.1.0".
Now prioritize using importlib.metadata to get real installed version.
Returns:
str: Version number
"""
# Method 1 (recommended): Use importlib.metadata (Python 3.8+)
try:
from importlib.metadata import version
return version("vllm-kunlun")
except Exception:
pass
# Method 2: Use pkg_resources
try:
import pkg_resources
return pkg_resources.get_distribution("vllm-kunlun").version
except Exception:
pass
# Method 3 (fallback): Get from code (may be inaccurate)
try:
from vllm_kunlun.platforms.version import get_xvllm_version
return get_xvllm_version() + " (from code, may be inaccurate)"
except ImportError:
pass
return "N/A"
def get_vllm_version():
"""Get vLLM main package version"""
try:
from importlib.metadata import version
return version("vllm")
except Exception:
pass
try:
from vllm import __version__
return __version__
except ImportError:
pass
return "N/A"
# =============================================================================
# Part 4: Environment Variable Collection
# =============================================================================
def get_kunlun_env_vars():
"""Get Kunlun-related environment variables"""
env_vars = ""
kunlun_prefixes = (
"XPU",
"KUNLUN",
"BKCL",
"XCCL",
"XRE",
"TORCH",
"VLLM",
)
secret_terms = ("secret", "token", "api", "access", "password")
for k, v in sorted(os.environ.items()):
if any(term in k.lower() for term in secret_terms):
continue
if any(k.upper().startswith(prefix) for prefix in kunlun_prefixes):
env_vars += f"{k}={v}\n"
return env_vars
# =============================================================================
# Part 5: Define Data Structure and Formatted Output
# =============================================================================
KunlunSystemEnv = namedtuple(
"KunlunSystemEnv",
[
# General system information
"os",
"gcc_version",
"clang_version",
"cmake_version",
"libc_version",
"python_version",
"python_platform",
"pip_version",
"pip_packages",
"conda_packages",
"cpu_info",
# PyTorch information
"torch_version",
"is_debug_build",
# Kunlun-specific information
"kunlun_xpu_info",
"kunlun_driver_version",
"kunlun_xre_version",
"bkcl_version",
"kunlun_topo",
# vLLM related
"vllm_version",
"vllm_kunlun_version",
"env_vars",
],
)
def get_kunlun_env_info():
"""Collect all environment information"""
run_lambda = run
    pip_version, pip_list_output = get_pip_packages(run_lambda)

    # PyTorch information
    if TORCH_AVAILABLE:
        torch_version = torch.__version__
        debug_mode_str = str(torch.version.debug)
    else:
        torch_version = "N/A"
        debug_mode_str = "N/A"

    sys_version = sys.version.replace("\n", " ")

    return KunlunSystemEnv(
        # General system information
        os=get_os(run_lambda),
        gcc_version=get_gcc_version(run_lambda),
        clang_version=get_clang_version(run_lambda),
        cmake_version=get_cmake_version(run_lambda),
        libc_version=get_libc_version(),
        python_version=f"{sys_version} ({sys.maxsize.bit_length() + 1}-bit runtime)",
        python_platform=get_python_platform(),
        pip_version=pip_version,
        pip_packages=pip_list_output,
        conda_packages=get_conda_packages(run_lambda),
        cpu_info=get_cpu_info(run_lambda),
        # PyTorch information
        torch_version=torch_version,
        is_debug_build=debug_mode_str,
        # Kunlun-specific information
        kunlun_xpu_info=get_kunlun_gpu_info(run_lambda),
        kunlun_driver_version=get_kunlun_driver_version(run_lambda),
        kunlun_xre_version=get_kunlun_xre_version(run_lambda),
        bkcl_version=get_bkcl_version(run_lambda),
        kunlun_topo=get_kunlun_topo(run_lambda),
        # vLLM related
        vllm_version=get_vllm_version(),
        vllm_kunlun_version=get_vllm_kunlun_version(),
        env_vars=get_kunlun_env_vars(),
    )


# Output format template
kunlun_env_info_fmt = """
==============================
        System Info
==============================
OS : {os}
GCC version : {gcc_version}
Clang version : {clang_version}
CMake version : {cmake_version}
Libc version : {libc_version}
==============================
       PyTorch Info
==============================
PyTorch version : {torch_version}
Is debug build : {is_debug_build}
==============================
     Python Environment
==============================
Python version : {python_version}
Python platform : {python_platform}
==============================
     Kunlun / XPU Info
==============================
XPU models and configuration :
{kunlun_xpu_info}
Kunlun driver version : {kunlun_driver_version}
XRE (Runtime) version : {kunlun_xre_version}
BKCL version : {bkcl_version}
XPU Topology:
{kunlun_topo}
==============================
         CPU Info
==============================
{cpu_info}
==============================
Versions of relevant libraries
==============================
{pip_packages}
{conda_packages}
==============================
     vLLM-Kunlun Info
==============================
vLLM Version : {vllm_version}
vLLM-Kunlun Version : {vllm_kunlun_version}
==============================
   Environment Variables
==============================
{env_vars}
""".strip()


def pretty_str(envinfo):
    """Format environment information"""
    mutable_dict = envinfo._asdict()

    # Replace None with "Could not collect"
    for key in mutable_dict:
        if mutable_dict[key] is None:
            mutable_dict[key] = "Could not collect"

    # Handle pip package list
    if mutable_dict["pip_packages"]:
        mutable_dict["pip_packages"] = "\n".join(
            f"[{envinfo.pip_version}] {line}"
            for line in mutable_dict["pip_packages"].split("\n")
            if line
        )
    else:
        mutable_dict["pip_packages"] = "No relevant packages"

    # Handle conda package list
    if mutable_dict["conda_packages"]:
        mutable_dict["conda_packages"] = "\n".join(
            f"[conda] {line}"
            for line in mutable_dict["conda_packages"].split("\n")
            if line
        )
    else:
        mutable_dict["conda_packages"] = ""

    return kunlun_env_info_fmt.format(**mutable_dict)


def get_pretty_kunlun_env_info():
    """Get formatted environment information"""
    return pretty_str(get_kunlun_env_info())


def main():
    """Main entry point"""
    print("Collecting Kunlun XPU environment information...")
    output = get_pretty_kunlun_env_info()
    print(output)


if __name__ == "__main__":
    main()


@@ -0,0 +1,760 @@
# 📖 vLLM-Kunlun New Model Adaptation Manual
> Based on in-depth analysis of [baidu/vLLM-Kunlun](https://github.com/baidu/vLLM-Kunlun) and [vllm-project/vllm](https://github.com/vllm-project/vllm) repositories.
>
> Applicable Versions: vLLM v0.15.1+ / vLLM-Kunlun main branch
---
## Table of Contents
- [I. Understanding the Overall Architecture](#i-understanding-the-overall-architecture)
- [1.1 Plugin System](#11-plugin-system)
- [1.2 Startup Process](#12-startup-process)
- [1.3 Import Hook Mechanism](#13-import-hook-mechanism)
- [1.4 Code Architecture](#14-code-architecture)
- [II. New Model Adaptation Step-by-Step](#ii-new-model-adaptation-step-by-step)
- [Step 0: Pre-assessment](#step-0-pre-assessment)
- [Step 1: Implement Model Files](#step-1-implement-model-files)
- [Step 2: Register the Model](#step-2-register-the-model)
- [Step 3: Verify Registration](#step-3-verify-registration)
- [Step 4: Testing](#step-4-testing)
- [III. Adaptation Guide for Special Model Types](#iii-adaptation-guide-for-special-model-types)
- [3.1 MoE Models](#31-moe-models-eg-qwen3-moe-deepseek-v3)
- [3.2 MLA Models](#32-mla-multi-latent-attention-models-eg-deepseek-v3)
- [3.3 Multi-modal Models](#33-multi-modal-models-eg-qwen2-vl-internvl)
- [3.4 Hybrid Attention Models](#34-hybrid-attention-models-eg-qwen3-next)
- [IV. Quantized Model Adaptation](#iv-quantized-model-adaptation)
- [4.1 Supported Quantization Methods](#41-supported-quantization-methods)
- [4.2 Special Handling for Quantization](#42-special-handling-for-quantization)
- [V. Custom Operators](#v-custom-operators-if-new-low-level-ops-are-needed)
- [VI. Common Pitfalls Checklist](#vi-common-pitfalls-checklist)
- [VII. Reference Template Quick Look-up](#vii-reference-template-quick-look-up)
- [VIII. Debugging Tips](#viii-debugging-tips)
- [IX. Environment Variables Cheat Sheet](#ix-environment-variables-cheat-sheet)
- [X. PR Submission Standards](#x-pr-submission-standards)
---
## I. Understanding the Overall Architecture
### 1.1 Plugin System
vLLM-Kunlun uses the **OOT (Out-of-Tree) Plugin** approach to integrate with vLLM, primarily registered via `entry_points` in `setup.py`:
```python
# setup.py
entry_points={
    'vllm.platform_plugins': ["kunlun = vllm_kunlun:register"],  # Platform Plugin
    'vllm.general_plugins': [
        "kunlun_model = vllm_kunlun:register_model",        # Model Registration
        "kunlun_quant = vllm_kunlun:register_quant_method"  # Quantization Method
    ],
    "console_scripts": [
        "vllm_kunlun = vllm_kunlun.entrypoints.main:main"
    ]
}
```
### 1.2 Startup Process
```
vllm Startup
├─ 1. Discover platform_plugin → Call vllm_kunlun:register()
│    ├─ Register KunlunPlatform (defines Attention Backend, Worker, etc.)
│    ├─ Apply import hook (module redirection)
│    └─ Register custom operators (custom_op)
├─ 2. Discover general_plugin → Call vllm_kunlun:register_model()
│    └─ Register all Kunlun-adapted models via ModelRegistry.register_model()
└─ 3. Model Loading → Match registered model classes based on the architectures field in config.json
```
### 1.3 Import Hook Mechanism
vLLM-Kunlun uses a custom import hook to **transparently replace** certain vLLM modules with Kunlun-customized versions:
```python
# vllm_kunlun/__init__.py
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
    try:
        module_mappings = {
            "vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
            "vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
            "vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
            "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
            "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
            "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
            "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states",
        }
        if module_name in module_mappings:
            if module_name in sys.modules:
                return sys.modules[module_name]
            target_module = module_mappings[module_name]
            module = importlib.import_module(target_module)
            sys.modules[module_name] = module
            sys.modules[target_module] = module
    except Exception:
        pass
    return OLD_IMPORT_HOOK(module_name, globals=globals, locals=locals, fromlist=fromlist, level=level)
```
> **⚠️ Understanding this mechanism is crucial**: Even if you use `from vllm.xxx import YYY` in your model code, what you actually get might be `vllm_kunlun.xxx.YYY`.
### 1.4 Code Architecture
```
vllm_kunlun/
├── __init__.py                       # Plugin Entry: register() + import_hook()
├── platforms/kunlun.py               # KunlunPlatform: Defines Attention Backend, Worker, etc.
├── models/                           # ⭐ Model Implementation Directory (where you add files)
│   ├── __init__.py                   # ⭐ Model Registration Entry
│   ├── deepseek_v2.py                # DeepSeek V2/V3 Reference Implementation
│   ├── deepseek_mtp.py               # DeepSeek MTP (Speculative Decoding)
│   ├── qwen3.py                      # Qwen3 Reference Implementation (Dense Model)
│   ├── qwen3_moe.py                  # Qwen3 MoE Reference Implementation
│   ├── qwen3_next.py                 # Qwen3-Next (Hybrid Attention)
│   ├── qwen3_vl.py                   # Qwen3 VL (Multi-modal)
│   ├── qwen3_vl_moe.py               # Qwen3 VL MoE (Multi-modal + MoE)
│   ├── qwen2_vl.py                   # Qwen2 VL
│   ├── qwen2_5_vl.py                 # Qwen2.5 VL
│   ├── internlm2.py                  # InternLM2 Reference Implementation
│   ├── internvl.py                   # InternVL (Multi-modal)
│   ├── interns1.py                   # InternS1
│   ├── seed_oss.py                   # SeedOss
│   ├── gpt_oss.py                    # GptOss
│   └── mimo_v2_flash.py              # MiMo-V2-Flash
├── ops/                              # Kunlun Custom Operators
│   ├── _kunlun_ops.py                # KunlunOps: paged_attention, rms_norm, silu...
│   ├── _custom_ops.py                # vllm custom_op registration
│   ├── activation.py                 # Activation functions like SiluAndMul, GeluAndMul
│   ├── attention/                    # Attention Operators
│   │   ├── layer.py                  # Attention Layer Wrapper
│   │   └── backends/kunlun_attn.py   # KunlunAttentionBackend + KunlunAttentionImpl
│   ├── quantization/                 # Quantization related: AWQ, GPTQ, CompressedTensors...
│   ├── vocab_parallel_embedding.py   # Custom Embedding
│   └── rotary_embedding.py           # Split_Norm_Rope (QKNorm + RoPE Fusion)
├── v1/attention/backends/            # Attention Backend for v1 Engine
│   ├── kunlun_attn.py                # Standard Attention
│   └── mla/                          # MLA (Multi-Latent Attention) Implementation
│       ├── flashmla.py
│       ├── flashmla_sparse.py
│       └── common.py
├── compilation/wrapper.py            # torch.compile Wrapper
├── config/                           # Model Configuration Overrides
│   └── model.py                      # Patch for attributes like is_deepseek_mla
├── distributed/                      # Communication related
│   └── kunlun_communicator.py        # Kunlun Device Communication
└── csrc/                             # C++ Extensions
    └── utils.cpp
```
---
## II. New Model Adaptation Step-by-Step
### Step 0: Pre-assessment
Before starting, confirm which scenario your model falls into:
| Scenario | Description | Effort |
|------|------|--------|
| **Case A: vLLM already supports the model** | Only need to replace Attention / Activation with Kunlun versions | ⭐ Minimal |
| **Case B: vLLM does not support, new architecture needed** | Requires full implementation of model class + registration | ⭐⭐⭐ High |
| **Case C: MoE variant of an existing model** | Add MoE layer on top of the Dense version | ⭐⭐ Medium |
| **Case D: Multi-modal model** | Language Model + Vision Encoder + Projector | ⭐⭐⭐⭐ Maximum |
**Recommended Workflow:**
1. Check the [vLLM Supported Models List](https://docs.vllm.ai/en/stable/models/supported_models.html) to see if the model is already there.
2. If yes → Copy the corresponding file from `vllm/model_executor/models/` to `vllm_kunlun/models/` and perform replacements.
3. If no → Refer to the [vLLM Adding a New Model Documentation](https://docs.vllm.ai/en/stable/contributing/model/) to understand the principles first, then follow this manual.
---
### Step 1: Implement Model Files
Create a model file in the `vllm_kunlun/models/` directory, e.g., `my_new_model.py`.
#### 1.1 Key Replacement Comparison Table
| Component | vLLM Native Import | vLLM-Kunlun Replacement Import | Required? |
|------|-----------------|------------------------|---------|
| **Attention Layer** | `from vllm.attention import Attention` | `from vllm_kunlun.ops.attention.layer import Attention` | ✅ **Yes** |
| **SiluAndMul** | `from vllm.model_executor.layers.activation import SiluAndMul` | `from vllm_kunlun.ops.activation import SiluAndMul` | ✅ **Yes** |
| **GeluAndMul** | `...activation import GeluAndMul` | `from vllm_kunlun.ops.activation import GeluAndMul` | ⚠️ As needed |
| **QuickGELU** | `...activation import QuickGELU` | `from vllm_kunlun.ops.activation import QuickGELU` | ⚠️ As needed |
| **VocabParallelEmbedding** | `from vllm...vocab_parallel_embedding import VocabParallelEmbedding` | `from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding` | ⚠️ Some models |
| **ParallelLMHead** | Same as above | `from vllm_kunlun.ops.vocab_parallel_embedding import ParallelLMHead` | ⚠️ Some models |
| **RoPE (Special)** | `from vllm...rotary_embedding import get_rope` | `from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope` | ⚠️ MoE+QKNorm |
| **Linear / RMSNorm, etc.** | Use vLLM native directly | **No replacement needed** | — |
> 💡 **Core Principle**: Any component involving **CUDA kernel calls** (Attention, Activation, Sampling) must be replaced with the Kunlun version; pure PyTorch components (Linear, RMSNorm, RoPE, etc.) can use vLLM native directly.
#### 1.2 Standard Dense Decoder-Only Model Template
Refer to `qwen3.py` or `internlm2.py`:
```python
"""Inference-only MyNewModel compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import MyNewModelConfig # HuggingFace config
# ==========================================
# ⭐ Key Replacement 1: Use Kunlun-customized Attention
# ==========================================
# Do not use from vllm.attention import Attention
from vllm_kunlun.ops.attention.layer import Attention
# ==========================================
# ⭐ Key Replacement 2: Use Kunlun-customized Activation
# ==========================================
# Do not use from vllm.model_executor.layers.activation import SiluAndMul
from vllm_kunlun.ops.activation import SiluAndMul
# Other layers can use vLLM native directly
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear, RowParallelLinear, MergedColumnParallelLinear
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP, SupportsLoRA
from vllm.model_executor.models.utils import (
AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix
)
# ============================
# 1. MLP Layer
# ============================
class MyNewModelMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act,
quant_config=None, prefix=""):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size,
bias=False, quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = SiluAndMul() # ⭐ Use Kunlun version
def forward(self, x):
# Implementation...
```
#### 1.3 Key Implementation Requirements
- **All modules must include the `prefix` parameter**, passed in `__init__()`.
- **`@support_torch_compile` decorator** must be added to the main model class (e.g., `MyNewModel`).
- **`load_weights()` method** must correctly handle weight name mapping (`stacked_params_mapping`).
- **Pipeline Parallelism (PP)** requires using tools like `PPMissingLayer`, `is_pp_missing_parameter`, etc.
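The name mapping behind `stacked_params_mapping` can be sketched as follows; the exact tuples depend on your checkpoint's parameter names, so treat this as an illustrative skeleton rather than the actual vLLM-Kunlun code:

```python
# (merged param name, checkpoint shard name, shard id) -- typical vLLM layout
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def map_weight_name(name: str):
    """Return (target parameter name, shard id) for a checkpoint weight."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None  # unstacked weights load as-is

print(map_weight_name("model.layers.0.mlp.gate_proj.weight"))
```

In the real `load_weights()`, the shard id is then passed to the parameter's `weight_loader` so the shard lands in the right slice of the merged tensor.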
---
### Step 2: Register the Model
Add registration code in `vllm_kunlun/models/__init__.py`:
```python
# vllm_kunlun/models/__init__.py
from vllm import ModelRegistry
def register_model():
    # ... Existing model registrations ...

    # ⭐ Add your new model (using the lazy-loading string format)
    ModelRegistry.register_model(
        "MyNewModelForCausalLM",  # ← Must match architectures in config.json
        "vllm_kunlun.models.my_new_model:MyNewModelForCausalLM"  # ← Module path:Class name
    )
```
**⚠️ Key Considerations:**
1. The **first parameter** of `register_model()` is the model's `architecture` identifier, which **must exactly match the `"architectures"` field in the HuggingFace model's `config.json`**.
2. Use the **string format** for the module path (`"module:class"`) to implement **lazy loading**, avoiding CUDA initialization conflicts (`RuntimeError: Cannot re-initialize CUDA in forked subprocess`).
3. If the model already exists in vLLM (e.g., `Qwen3ForCausalLM`), the Kunlun version will **overwrite** the original vLLM version upon registration.
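The lazy-loading behaviour of the `"module:class"` string boils down to a deferred `importlib` lookup; a small sketch, demonstrated with a stdlib class since importing a real model class needs Kunlun hardware:

```python
import importlib

def resolve_lazy(path: str):
    # "module:Class" defers the import (and any device initialization)
    # until the registry actually needs the class.
    module_name, _, class_name = path.partition(":")
    return getattr(importlib.import_module(module_name), class_name)

print(resolve_lazy("collections:OrderedDict"))
```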
---
### Step 3: Verify Registration
### Case A: Overwriting an Existing vLLM Model Architecture
If your model architecture name (e.g., `"Qwen3ForCausalLM"`) already exists in vLLM, vLLM will output the following log during registration:
```
WARNING [...] Model architecture Qwen3ForCausalLM is already registered,
and will be overwritten by the new model class
vllm_kunlun.models.qwen3:Qwen3ForCausalLM.
```
Seeing this log indicates a successful overwrite ✅.
### Case B: Brand New Model Architecture
If you are registering an architecture that does not exist in vLLM, there is no default log confirmation. It is recommended to verify manually during the debugging phase:
```python
from vllm import ModelRegistry
assert "MyNewModelForCausalLM" in ModelRegistry.get_supported_archs()
print("✅ Model registration successful!")
```
---
### Step 4: Testing
#### 4.1 Offline Inference Test
```python
from vllm import LLM, SamplingParams
llm = LLM(
    model="/path/to/MyNewModel",
    trust_remote_code=True,
    dtype="float16",
    tensor_parallel_size=1,  # Verify with a single card first
)
outputs = llm.generate(
    ["Hello, please introduce yourself."],
    SamplingParams(temperature=0.7, max_tokens=256),
)
for output in outputs:
    print(output.outputs[0].text)
```
#### 4.2 Online Service Test
```bash
XPU_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 --port 8888 \
--model /path/to/MyNewModel \
--trust-remote-code \
--dtype float16 \
--max-model-len 4096 \
--block-size 64
```
#### 4.3 Accuracy Verification
It is recommended to compare results with HuggingFace Transformers CPU/GPU inference:
```python
# Transformers reference output (greedy, for comparison)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("/path/to/MyNewModel", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("/path/to/MyNewModel")
inputs = tokenizer("Hello, please introduce yourself.", return_tensors="pt")
ref_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(ref_ids[0], skip_special_tokens=True))
```
---
## III. Adaptation Guide for Special Model Types
### 3.1 MoE Models (e.g., Qwen3-MoE, DeepSeek-V3)
**Reference Files:**
- `vllm_kunlun/models/qwen3_moe.py`
- `vllm_kunlun/models/deepseek_v2.py`
**Additional Points:**
- Use `vllm.model_executor.layers.fused_moe.layer.FusedMoE`; Kunlun has replaced the underlying kernel via import hook.
- MoE's `load_weights()` is more complex, requiring expert parameter mapping:
```python
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=config.n_routed_experts,
)
```
- Recommended environment variables:
```bash
export KUNLUN_USE_MOE_FFN_BLOCK=True
export XPU_USE_MOE_SORTED_THRES=120
```
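For intuition, the router math that `FusedMoE` fuses on-device is just softmax gating plus top-k selection and renormalization; a hedged, eager-PyTorch sketch (not the Kunlun kernel):

```python
import torch

def topk_gate(hidden: torch.Tensor, gate_weight: torch.Tensor, top_k: int = 2):
    logits = hidden @ gate_weight                   # [tokens, num_experts]
    probs = torch.softmax(logits, dim=-1)
    weights, expert_ids = torch.topk(probs, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize kept experts
    return weights, expert_ids

w, ids = topk_gate(torch.randn(4, 16), torch.randn(16, 8))
print(w.shape, ids.shape)
```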
### 3.2 MLA (Multi-Latent Attention) Models (e.g., DeepSeek-V3)
**Reference File:** `vllm_kunlun/models/deepseek_v2.py`
**MLA Special Handling:**
- KV compression dimensions: `kv_lora_rank`, `qk_nope_head_dim`, `qk_rope_head_dim`.
- Platform layer automatically selects `FlashMLABackend`:
```python
# vllm_kunlun/platforms/kunlun.py
if use_mla:
    if use_sparse:
        return "vllm_kunlun.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
    return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend"
```
- `block_size` usually needs to be set to **64**.
- Recommended setting: `export USE_ORI_ROPE=1`.
### 3.3 Multi-modal Models (e.g., Qwen2-VL, InternVL)
**Reference Files:**
- `vllm_kunlun/models/qwen3_vl.py`
- `vllm_kunlun/models/internvl.py`
- `vllm_kunlun/models/interns1.py`
**Additional Components to Implement:**
| Component | Description |
|------|------|
| `SupportsMultiModal` Interface | Declares that the model supports multi-modal input |
| Vision Encoder | Usually `InternVisionModel` or custom ViT |
| Projector | Vision → Language mapping (e.g., MLP) |
| `@MULTIMODAL_REGISTRY.register_processor(...)` | Register multi-modal processor |
| `BaseMultiModalProcessor` | Handles multi-modal input |
| `BaseProcessingInfo` | Handles processing info |
| `BaseDummyInputsBuilder` | Dummy inputs for the profiling phase |
### 3.4 Hybrid Attention Models (e.g., Qwen3-Next)
**Reference File:** `vllm_kunlun/models/qwen3_next.py`
This model contains both **Linear Attention** and **Full Attention** layer types:
```python
# Select different attention calculations based on layer_type
if self.layer_type == "linear_attention":
    self.linear_attn(hidden_states=hidden_states, output=self_attention_output)
elif self.layer_type == "full_attention":
    self.self_attn(hidden_states=hidden_states, output=self_attention_output, positions=positions)
```
Note:
- Linear Attention uses `GatedDeltaNet` or similar implementations.
- Need to register custom `custom_op` (e.g., `vllm.gdn_attention`) for `splitting_ops` in `torch.compile`.
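For intuition, linear attention replaces the softmax KV cache with a fixed-size recurrent state; a deliberately simplified sketch without the learned gating/decay that `GatedDeltaNet` adds on top:

```python
import torch

def linear_attention_step(state: torch.Tensor, q, k, v):
    # The state accumulates k ⊗ v outer products; the query reads it out.
    state = state + torch.outer(k, v)
    return state, q @ state

d = 8
state = torch.zeros(d, d)
for _ in range(3):  # constant memory regardless of sequence length
    state, out = linear_attention_step(state, torch.randn(d), torch.randn(d), torch.randn(d))
print(out.shape)
```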
---
## IV. Quantized Model Adaptation
### 4.1 Supported Quantization Methods
| Quantization Method | Adaptation File | Status |
|---------|---------|------|
| **INT8 Dynamic (W8A8)** | `ops/quantization/kernels/kunlun_scale_mm.py` | ✅ Recommended |
| **AWQ (INT4)** | `ops/quantization/awq.py` | ✅ Supported |
| **GPTQ (INT4)** | `ops/quantization/gptq.py` | ✅ Supported |
| **CompressedTensors (INT8 MoE)** | `ops/quantization/compressed_tensors/` | ✅ Supported |
| **FP8** | — | ⚠️ Partial Support |
| **bfloat16** | — | ⚠️ Double VRAM bug |
### 4.2 Special Handling for Quantization
Kunlun chips use the **max value** for scale calculation instead of vLLM's default absmax:
```python
# ops/quantization/kernels/kunlun_scale_mm.py
class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel):

    def process_weights_after_loading(self, layer):
        super().process_weights_after_loading(layer)
        # ⭐ Key: Multiply scale by 127.0 to convert to max format
        with torch.no_grad():
            getattr(layer, self.w_s_name).mul_(127.0)
```
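A toy numeric check of this conversion (assuming the stored scale follows the usual `absmax / 127` convention):

```python
import torch

w = torch.tensor([-3.0, 1.5, 2.5])
absmax_scale = w.abs().max() / 127.0   # absmax-style scale
kunlun_scale = absmax_scale * 127.0    # "max" format: recovers w.abs().max()
print(float(kunlun_scale))
```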
INT4 weights need to be **repacked** into the Kunlun layout order:
```python
# AWQ repack example
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL]
```
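The reorder itself is a plain fancy-index along the last axis; a toy demonstration with unpacked nibble positions 0–7:

```python
import torch

AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
unpacked_awq = torch.arange(8).repeat(3, 1)   # toy rows of 8 unpacked INT4 values
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL]
print(unpacked_kunlun[0].tolist())  # [4, 0, 5, 1, 6, 2, 7, 3]
```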
---
## V. Custom Operators (if new low-level Ops are needed)
If your model requires new low-level operators:
### 5.1 Wrap kunlun_ops calls in `_kunlun_ops.py`
```python
# vllm_kunlun/ops/_kunlun_ops.py
class KunlunOps:

    @staticmethod
    def my_new_op(input, weight, out):
        """Call the underlying kunlun_ops implementation"""
        kunlun_ops.my_new_op(input, weight, out=out)
```
### 5.2 Register to vLLM in `_custom_ops.py`
Follow the **three-piece pattern**:
```python
# vllm_kunlun/ops/_custom_ops.py
# 1. Define the actual implementation of the op
def my_new_op_impl(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(input)
    KunlunOps.my_new_op(input, weight, output)
    return output

# 2. Define the fake-tensor implementation (for torch.compile)
def my_new_op_fake(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input)

# 3. Register
direct_register_custom_op(
    op_name="my_new_op",
    op_func=my_new_op_impl,
    mutates_args=[],
    fake_impl=my_new_op_fake,
)
```
---
## VI. Common Pitfalls Checklist
Before submitting a PR, please check each item:
- [ ] **Attention** uses `vllm_kunlun.ops.attention.layer.Attention`?
- [ ] **Activation functions** use `vllm_kunlun.ops.activation.SiluAndMul`, etc.?
- [ ] All submodules in `__init__()` have the `prefix` parameter passed?
- [ ] `load_weights()` correctly handles weight name mapping (`stacked_params_mapping`)?
- [ ] `@support_torch_compile` decorator is added to the main model class?
- [ ] The first parameter of `ModelRegistry.register_model()` exactly matches `architectures` in `config.json`?
- [ ] No use of `VLLM_USE_V1` environment variable for logic (deprecated, v0.15.1 is V1-only)?
- [ ] Type annotations use `Optional[T]` instead of `T | None` (to avoid `infer_schema` failure)?
- [ ] Quantized model scales are correctly multiplied by `127.0`?
- [ ] Supports Pipeline Parallelism (using `PPMissingLayer`, `is_pp_missing_parameter`)?
- [ ] Ran `pre-commit` format checks?
- [ ] Commits use `-s` signature (DCO compliance)?
---
## VII. Reference Template Quick Look-up
| Model Type | Best Reference File | Features |
|---------|------------|------|
| Standard Dense LLM | `qwen3.py` | Simplest, recommended for beginners |
| Dense LLM (Custom Embedding) | `seed_oss.py`, `internlm2.py` | Custom VocabParallelEmbedding |
| MoE LLM | `qwen3_moe.py` | FusedMoE + EP + SharedExpert |
| MLA + MoE (DeepSeek) | `deepseek_v2.py` | MLA attention + MoE + Indexer |
| Hybrid Attention | `qwen3_next.py` | Linear + Full attention |
| Multi-modal (VL) | `qwen3_vl.py`, `internvl.py` | ViT + Projector + LLM |
| Speculative Decoding (MTP) | `deepseek_mtp.py` | Multi-Token Prediction |
---
## VIII. Debugging Tips
### 8.1 Startup Failure
- **`ModuleNotFoundError`**: Check if the import hook mapping table in `__init__.py` covers the corresponding module.
- **`circular import`**: Check if your new code introduces heavy dependencies during the `register()` phase.
- **`Model architecture XXX is not supported`**: Check if the first parameter of `register_model()` matches `config.json`.
### 8.2 Abnormal Output
- **Garbage output**: Compare with HF transformers output on CPU; likely an operator precision issue or weight loading mapping error.
- **Repeated tokens**: Check if `rotary_embedding` is applied correctly and if the `is_neox_style` parameter is correct.
- **Truncated output**: Check `max_model_len` settings and if KV cache is sufficient.
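The two `is_neox_style` layouts differ only in which dimension pairs get rotated together; a minimal sketch to compare against your reference implementation (shapes assume `cos`/`sin` are already broadcast to half the head dimension):

```python
import torch

def apply_rope_neox(x, cos, sin):
    # NeoX style: rotate the two contiguous halves of the head dim.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def apply_rope_gptj(x, cos, sin):
    # GPT-J style: rotate interleaved even/odd pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return out.flatten(-2)
```

With a mismatched style the rotation pairs the wrong dimensions, which typically shows up as repeated or degenerate tokens rather than a hard error.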
### 8.3 VRAM Issues
- Use `--dtype float16` (avoid bfloat16 due to double VRAM bug).
- Set `VLLM_KUNLUN_ENABLE_INT8_BMM=1` (saves ~0.1GB).
- Lower `--gpu-memory-utilization` (default is 0.9).
- Use INT8 quantized models.
### 8.4 Weight Loading Failure
```python
# Debugging method: print parameter names for comparison
params_dict = dict(self.named_parameters())
print("=== Model params ===")
for k in sorted(params_dict.keys()):
    print(f"  {k}: {params_dict[k].shape}")

# Print inside load_weights
for name, loaded_weight in weights:
    if name not in params_dict:
        print(f"  ⚠️ Skipped: {name}")
```
### 8.5 Kunlun Graph Failure
Confirm that `splitting_ops` in `compilation-config` includes your attention op name:
```json
{
  "splitting_ops": [
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
    "vllm.unified_attention_with_output_kunlun",
    "vllm.sparse_attn_indexer_vllm_kunlun"
  ],
  "cudagraph_mode": "PIECEWISE"
}
```
---
## IX. Environment Variables Cheat Sheet
```bash
# === Required ===
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 # Specify Kunlun cards to use
export VLLM_HOST_IP=$(hostname -i) # IP for distributed communication
# === Recommended ===
export XMLIR_FORCE_USE_XPU_GRAPH=1 # Enable XPU Graph acceleration
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false # Disable mock compile
export XMLIR_CUDNN_ENABLED=1 # Enable cuDNN equivalent acceleration
export XPU_USE_DEFAULT_CTX=1 # Default context
export BKCL_FORCE_SYNC=1 # BKCL forced sync (multi-card stability)
# === Model Specific ===
export USE_ORI_ROPE=1 # DeepSeek series uses original RoPE
export XFT_USE_FAST_SWIGLU=1 # Fast SwiGLU activation
export XPU_USE_FAST_SWIGLU=1 # Same as above (some versions)
export XPU_USE_MOE_SORTED_THRES=120 # MoE sorting threshold
export KUNLUN_USE_MOE_FFN_BLOCK=True # MoE FFN block optimization
# === Optional Tuning ===
export VLLM_KUNLUN_ENABLE_INT8_BMM=1 # Enable INT8 BMM (saves ~0.1GB)
```
---
## X. PR Submission Standards
### 10.1 Branch Naming
```
feature/add-my-new-model
bugfix/fix-attention-output
```
### 10.2 Commit Message Prefix
| Prefix | Description |
|------|------|
| `[Feature]` | New functionality / New model |
| `[Bugfix]` | Bug fix |
| `[CI/Build]` | CI / Build related |
| `[Doc]` | Documentation update |
| `[Misc]` | Others |
### 10.3 Before Submission
```bash
# 1. Install pre-commit
pre-commit install
# 2. Run checks
pre-commit run --all-files
# 3. Signed commit (DCO compliance)
git commit -s -m "[Feature] Add MyNewModel support for Kunlun"
```
### 10.4 PR Checklist
- [ ] Code passes `pre-commit` checks.
- [ ] Single-card offline inference test passed.
- [ ] Multi-card TP test passed (if applicable).
- [ ] Quantized model test passed (if applicable).
- [ ] Updated `vllm_kunlun/models/__init__.py` registration.
- [ ] Updated supported models list in README (if applicable).
---
## Appendix: Standard Startup Command Templates
### A. Standard Dense Model (Single Card)
```bash
XPU_VISIBLE_DEVICES=0 \
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 --port 8888 \
--model /path/to/model \
--trust-remote-code \
--dtype float16 \
--max-model-len 8192 \
--block-size 64
```
### B. MoE Model (8-card TP)
```bash
XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
XMLIR_FORCE_USE_XPU_GRAPH=1 \
KUNLUN_USE_MOE_FFN_BLOCK=True \
XPU_USE_MOE_SORTED_THRES=120 \
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 --port 8888 \
--model /path/to/moe-model-int8 \
--trust-remote-code \
--dtype float16 \
--max-model-len 32768 \
--tensor-parallel-size 8 \
--max_num_seqs 4 \
--block-size 64 \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--no-enable-prefix-caching
```
### C. DeepSeek-V3 (MLA + MoE, W8A8)
```bash
XMLIR_ENABLE_MOCK_TORCH_COMPILE=false \
USE_ORI_ROPE=1 \
XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 --port 8806 \
--model /path/to/DeepSeek-V3-w8a8 \
--gpu-memory-utilization 0.98 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 8 \
--dtype float16 \
--max_num_seqs 4 \
--block-size 64 \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--no-enable-prefix-caching
```
---
> 📝 **Document Maintenance**: If you have questions or suggestions, please provide feedback in [GitHub Issues](https://github.com/baidu/vLLM-Kunlun/issues).


@@ -6,17 +6,23 @@ torch_xray is an operator precision analysis tool that can dump module-level inp
### 1. Download and install
**python3.12:**
```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp312-cp312-linux_x86_64.whl"
```
**python3.10:**
```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl"
```
**python3.8:**
```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp38-cp38-linux_x86_64.whl"
```
Note that the same installation package must be used when the tool is run in different environments.


@@ -52,6 +52,7 @@ user_guide/release_notes
:::{toctree}
:caption: Developer Guide
:maxdepth: 1
developer_guide/developer_guide
developer_guide/contribution/index
developer_guide/feature_guide/index
developer_guide/evaluation/index


@@ -75,55 +75,34 @@ cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/l
## Choose to download customized xpytorch
### Install the KL3-customized build of PyTorch
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/baidu-kunlun-share/20260206/xpytorch-cp310-torch251-ubuntu2004-x64.run
# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
```
### Install the KL3-customized build of PyTorch (Only MIMO V2)
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run
# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
```
### Install the KL3-customized build of PyTorch (Only DeepSeek-V3.2-Exp-w8a8)
```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://aihc-private-hcd.bj.bcebos.com/v1/vllm-kunlun-ds/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2026-02-03T01%3A59%3A40Z%2F-1%2Fhost%2Ffc4b6f5b83c2fde70d48fdfc23c40c396efc9cb3c36d6f811fdca5f109073321
# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run
# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
mv torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl && \
sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g; s/torch_xray-999.9.9/torch_xray-2.0.3/' setup.sh && bash setup.sh
```
## Choose to download customized ops
### Install custom ops
```
uv pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
```
### Install custom ops (Only MIMO V2)
```
uv pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
```
### Install custom ops (Only DeepSeek-V3.2-Exp-w8a8)
```
uv pip install "https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1215/xtorch_ops-0.1.2263%2Bc030eebd-cp310-cp310-linux_x86_64.whl"
uv pip install "https://baidu-kunlun-public.su.bcebos.com/baidu-kunlun-share/20260206/kunlun_ops-0.1.45%2Bbac5499e-cp310-cp310-linux_x86_64.whl"
```
## Install the KLX3 custom Triton build
```
uv pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
```
## Install the AIAK custom ops library
```
uv pip install "https://vllm-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20260130_152557/xspeedgate_ops-0.0.0%2Be5cdcbe-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKhvtgrTA8US5LIc8Vbl0mP%2F2026-01-30T10%3A33%3A32Z%2F2592000%2Fhost%2F3c13d67cc61d0df7538c198f5c32422f3b034068a40eef43cb51b079cc6f0555" --force-reinstall
```
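After installing the wheels above, it is worth confirming they are all importable before launching the server. A minimal sketch, assuming the module names match the wheel names (`xtorch_ops`, `kunlun_ops`, `xspeedgate_ops`, `triton`); adjust the list to the wheels you actually installed:

```python
# Check which of the installed packages can be imported in this environment.
import importlib.util


def check_modules(names):
    """Return a dict mapping module name -> True if importable."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Module names are assumed from the wheel filenames above.
    status = check_modules(
        ["torch", "triton", "xtorch_ops", "kunlun_ops", "xspeedgate_ops"]
    )
    for name, ok in status.items():
        print(f"{name}: {'OK' if ok else 'MISSING'}")
```

Any `MISSING` entry means the corresponding install step above did not take effect in the active environment.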

View File

@@ -8,5 +8,7 @@ single_xpu_Qwen3-VL-32B
single_xpu_InternVL2_5-26B
multi_xpu_Qwen2.5-VL-32B
multi_xpu_GLM-4.5
multi_xpu_GLM-5-W8A8-INT8
multi_xpu_DeepSeek-V3.2-Exp-w8a8
multi_xpu_Qwen3-Coder-480B-A35B(W8A8)
:::

View File

@@ -7,6 +7,7 @@ Setup environment using container:
Please follow the [installation.md](../installation.md) document to set up the environment first, then create a container:
```bash
#!/bin/bash
# rundocker.sh
@@ -36,13 +37,16 @@ docker run -itd ${DOCKER_DEVICE_CONFIG} \
### Weight Preparation
- Pull DeepSeek-V3.2-Exp-w8a8-int8 weights
```
wget -O DeepSeek-V3.2-Exp-w8a8-int8.tar.gz https://aihc-private-hcd.bj.bcebos.com/v1/LLM/DeepSeek/DeepSeek-V3.2-Exp-w8a8-int8.tar.gz?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2025-12-24T06%3A07%3A10Z%2F-1%2Fhost%2Fa324bf469176934a05f75d3acabc3c1fb891be150f43fb1976e65b7ec68733db
```
- Ensure that the field "quantization_config" is included. If it is missing, deployment will fail with an OOM (Out of Memory) error.
vim model/DeepSeek-V3.2-Exp-w8a8-int8/config.json
```json
"quantization_config": {
"config_groups": {
"group_0": {
@@ -108,7 +112,7 @@ export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
export XPU_USE_MOE_SORTED_THRES=1 && \
export USE_ORI_ROPE=1 && \
export VLLM_USE_V1=1
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
@@ -129,9 +133,9 @@ python -m vllm.entrypoints.openai.api_server \
--compilation-config '{"splitting_ops":["vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_attention_with_output_kunlun",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",

View File

@@ -0,0 +1,92 @@
# Multi XPU (GLM-5-W8A8-INT8)
## Run vllm-kunlun on Multi XPU
Set up the environment using a container:
Please follow the [installation.md](../installation.md) document to set up the environment first, then create a container:
```bash
#!/bin/bash
# rundocker.sh
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
for idx in $(seq 0 $((XPU_NUM-1))); do
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
done
DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi
export build_image="xxx"
docker run -itd ${DOCKER_DEVICE_CONFIG} \
--net=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
--tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
--cap-add=SYS_PTRACE \
-v /home/users/vllm-kunlun:/home/vllm-kunlun \
-v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
--name "$1" \
-w /workspace \
"$build_image" /bin/bash
```
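The device loop above simply emits one `--device` mapping per XPU plus the control device. The same expansion can be sketched in Python, which is handy when generating launch commands from an orchestrator (paths follow the script above; purely illustrative):

```python
def xpu_device_flags(xpu_num: int) -> str:
    """Build the docker --device arguments that the rundocker.sh loop produces."""
    flags = []
    for idx in range(xpu_num):
        # One mapping per XPU device node.
        flags.append(f"--device=/dev/xpu{idx}:/dev/xpu{idx}")
    if xpu_num > 0:
        # The control device is always added when any XPU is exposed.
        flags.append("--device=/dev/xpuctrl:/dev/xpuctrl")
    return " ".join(flags)


print(xpu_device_flags(2))
# --device=/dev/xpu0:/dev/xpu0 --device=/dev/xpu1:/dev/xpu1 --device=/dev/xpuctrl:/dev/xpuctrl
```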
### Weight Preparation
- Pull GLM-5-W8A8-INT8 weights
```
wget -O GLM-5-W8A8-INT8-Dynamic.tar.gz https://aihc-private-hcd.bj.bcebos.com/LLM/AICapX-Quant-Models/GLM-5-W8A8-INT8-Dynamic.tar.gz
```
### Online Serving on Multi XPU
Start the vLLM server on multi XPU:
```bash
unset XPU_DUMMY_EVENT && \
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
export XMLIR_CUDNN_ENABLED=1 && \
export XPU_USE_DEFAULT_CTX=1 && \
export XMLIR_FORCE_USE_XPU_GRAPH=1 && \
export XMLIR_ENABLE_FAST_FC=1 && \
export XPU_USE_FAST_SWIGLU=1 && \
export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
export XPU_USE_MOE_SORTED_THRES=1 && \
export USE_ORI_ROPE=1 && \
export VLLM_USE_V1=1
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 8806 \
--model GLM-5-W8A8-INT8-Dynamic \
--gpu-memory-utilization 0.97 \
--trust-remote-code \
--max-model-len 32768 \
--tensor-parallel-size 8 \
--dtype bfloat16 \
--max_num_seqs 8 \
--max_num_batched_tokens 8192 \
--block-size 64 \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--disable-log-requests \
--no-enable-prefix-caching \
--kv-cache-dtype bfloat16 \
--compilation-config '{
"splitting_ops":[
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_attention_with_output_kunlun",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer",
"vllm.sparse_attn_indexer_vllm_kunlun"
]}'
```
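Once the server is up, it can be queried through the OpenAI-compatible completions endpoint. A minimal client sketch; the port (8806) and model name are taken from the launch command above, while the prompt and sampling parameters are illustrative:

```python
import json
import urllib.request

# Port and model name follow the vLLM launch command above; adjust if changed.
API_URL = "http://localhost:8806/v1/completions"


def build_payload(prompt: str, max_tokens: int = 128) -> dict:
    """Build an OpenAI-compatible completion request body."""
    return {
        "model": "GLM-5-W8A8-INT8-Dynamic",
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
    }


def query(prompt: str) -> dict:
    """POST the payload to the running server and return the parsed response."""
    req = urllib.request.Request(
        API_URL,
        data=json.dumps(build_payload(prompt)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


# Example (requires the server started above to be running):
# print(query("你好!你是谁?")["choices"][0]["text"])
```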

View File

@@ -86,8 +86,10 @@ if __name__ == "__main__":
main()
```
:::::
If you run this script successfully, you can see the info shown below:
```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': '你好!你是谁?'}]}]
@@ -95,9 +97,11 @@ Model response:
你好!我是一个由人工智能驱动的助手,旨在帮助回答问题、提供信息和解决日常问题。请问有什么我可以帮助你的?
==================================================
```
### Online Serving on Single XPU
Start the vLLM server on a single XPU:
```text
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 9988 \
@@ -114,25 +118,29 @@ python -m vllm.entrypoints.openai.api_server \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name InternVL2_5-26B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_attention_with_output_kunlun",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer"]}'
# Version 0.11.0
```
If your service starts successfully, you can see the info shown below:
```bash
(APIServer pid=157777) INFO: Started server process [157777]
(APIServer pid=157777) INFO: Waiting for application startup.
(APIServer pid=157777) INFO: Application startup complete.
```
Once your server is started, you can query the model with input prompts:
```bash
curl http://localhost:9988/v1/completions \
-H "Content-Type: application/json" \
@@ -145,17 +153,23 @@ curl http://localhost:9988/v1/completions \
"top_k": 50
}'
```
If you query the server successfully, you can see the info shown below (client):
```bash
{"id":"cmpl-23a24afd616d4a47910aeeccb20921ed","object":"text_completion","created":1768891222,"model":"InternVL2_5-26B","choices":[{"index":0,"text":" 你有什么问题吗?\n\n你好我是书生·AI很高兴能与你交流。请问有什么我可以帮助你的吗无论是解答问题、提供信息还是其他方面的帮助我都会尽力而为。请告诉我你的需求。","logprobs":null,"finish_reason":"stop","stop_reason":92542,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":6,"total_tokens":53,"completion_tokens":47,"prompt_tokens_details":null},"kv_transfer_params":null}
```
Logs of the vllm server:
```bash
(APIServer pid=161632) INFO: 127.0.0.1:56708 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=161632) INFO 01-20 14:40:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.6 tokens/s, Avg generation throughput: 4.6 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=161632) INFO 01-20 14:40:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```
Input an image for testing. Here, a Python script is used:
```python
import requests
import base64
@@ -193,13 +207,17 @@ payload = {
response = requests.post(API_URL, json=payload)
print(response.json())
```
If you query the server successfully, you can see the info shown below (client):
```bash
{'id': 'chatcmpl-9aeab6044795458da04f2fdcf1d0445d', 'object': 'chat.completion', 'created': 1768891349, 'model': 'InternVL2_5-26B', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': '你好这张图片上有一个黄色的笑脸表情符号双手合十旁边写着“Hugging Face”。这个表情符号看起来很开心似乎在表示拥抱或欢迎。', 'refusal': None, 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [], 'reasoning_content': None}, 'logprobs': None, 'finish_reason': 'stop', 'stop_reason': 92542, 'token_ids': None}], 'service_tier': None, 'system_fingerprint': None, 'usage': {'prompt_tokens': 790, 'total_tokens': 827, 'completion_tokens': 37, 'prompt_tokens_details': None}, 'prompt_logprobs': None, 'prompt_token_ids': None, 'kv_transfer_params': None}
```
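A chat response like the one above can be unpacked programmatically. A small helper sketch; the field names follow the OpenAI-style schema shown in the response:

```python
def summarize_chat_response(resp: dict) -> dict:
    """Extract the assistant reply and token counts from a chat.completion dict."""
    choice = resp["choices"][0]
    usage = resp.get("usage") or {}
    return {
        "content": choice["message"]["content"],
        "finish_reason": choice["finish_reason"],
        "prompt_tokens": usage.get("prompt_tokens"),
        "completion_tokens": usage.get("completion_tokens"),
    }


# Usage: pass the parsed JSON from response.json() in the script above.
# summarize_chat_response(response.json())
```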
Logs of the vllm server:
```bash
(APIServer pid=161632) INFO: 127.0.0.1:58686 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=161632) INFO 01-20 14:42:35 [loggers.py:127] Engine 000: Avg prompt throughput: 79.0 tokens/s, Avg generation throughput: 3.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=161632) INFO 01-20 14:42:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

View File

@@ -85,19 +85,23 @@ if __name__ == "__main__":
main()
```
:::::
If you run this script successfully, you can see the info shown below:
```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': 'tell a joke'}]}]
Model response:
Why dont skeletons fight each other?
Because they dont have the guts! 🦴😄
==================================================
```
### Online Serving on Single XPU
Start the vLLM server on a single XPU:
```text
python -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--port 9988 \
@@ -114,25 +118,29 @@ python -m vllm.entrypoints.openai.api_server \
--no-enable-chunked-prefill \
--distributed-executor-backend mp \
--served-model-name Qwen3-VL-32B \
--compilation-config '{"splitting_ops": ["vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.unified_attention_with_output_kunlun",
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer"]}'
# Version 0.11.0
```
If your service starts successfully, you can see the info shown below:
```bash
(APIServer pid=109442) INFO: Started server process [109442]
(APIServer pid=109442) INFO: Waiting for application startup.
(APIServer pid=109442) INFO: Application startup complete.
```
Once your server is started, you can query the model with input prompts:
```bash
curl http://localhost:9988/v1/completions \
-H "Content-Type: application/json" \
@@ -143,11 +151,15 @@ curl http://localhost:9988/v1/completions \
"temperature": 0
}'
```
If you query the server successfully, you can see the info shown below (client):
```bash
{"id":"cmpl-4f61fe821ff34f23a91baade5de5103e","object":"text_completion","created":1768876583,"model":"Qwen3-VL-32B","choices":[{"index":0,"text":" 你好!我是通义千问,是阿里云研发的超大规模语言模型。我能够回答问题、创作文字、编程等,还能根据你的需求进行多轮对话。有什么我可以帮你的吗?😊\n\n温馨提示我是一个AI助手虽然我尽力提供准确和有用的信息但请记得在做重要决策时最好结合专业意见或进一步核实信息哦","logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":90,"completion_tokens":85,"prompt_tokens_details":null},"kv_transfer_params":null}
```
Logs of the vllm server:
```bash
(APIServer pid=109442) INFO: 127.0.0.1:19962 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=109442) INFO 01-20 10:36:28 [loggers.py:127] Engine 000: Avg prompt throughput: 0.5 tokens/s, Avg generation throughput: 8.5 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
@@ -155,7 +167,9 @@ Logs of the vllm server:
(APIServer pid=109442) INFO 01-20 10:43:23 [chat_utils.py:560] Detected the chat template content format to be 'openai'. You can set `--chat-template-content-format` to override this.
(APIServer pid=109442) INFO 01-20 10:43:28 [loggers.py:127] Engine 000: Avg prompt throughput: 9.0 tokens/s, Avg generation throughput: 6.9 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.5%, Prefix cache hit rate: 0.0%
```
Input an image for testing. Here, a Python script is used:
```python
import requests
import base64
@@ -191,11 +205,15 @@ payload = {
response = requests.post(API_URL, json=payload)
print(response.json())
```
If you query the server successfully, you can see the info shown below (client):
If you query the server successfully, you can see the info shown below (client):
```bash
{'id': 'chatcmpl-4b42fe46f2c84991b0af5d5e1ffad9ba', 'object': 'chat.completion', 'created': 1768877003, 'model': 'Qwen3-VL-32B', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': '你好这张图片展示的是“Hugging Face”的标志。\n\n图片左侧是一个黄色的圆形表情符号emoji它有着圆圆的眼睛、张开的嘴巴露出微笑双手合拢在脸颊两侧做出一个拥抱或欢迎的姿态整体传达出友好、温暖和亲切的感觉。\n\n图片右侧是黑色的英文文字“Hugging Face”字体简洁现代与左侧的表情符号相呼应。\n\n整个标志设计简洁明了背景为纯白色突出了标志本身。这个标志属于Hugging Face公司它是一家知名的开源人工智能公司尤其在自然语言处理NLP领域以提供预训练模型如Transformers库和模型托管平台而闻名。\n\n整体来看这个标志通过可爱的表情符号和直白的文字成功传达了公司“拥抱”技术、开放共享、友好的品牌理念。', 'refusal': None, 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [], 'reasoning_content': None}, 'logprobs': None, 'finish_reason': 'stop', 'stop_reason': None, 'token_ids': None}], 'service_tier': None, 'system_fingerprint': None, 'usage': {'prompt_tokens': 90, 'total_tokens': 266, 'completion_tokens': 176, 'prompt_tokens_details': None}, 'prompt_logprobs': None, 'prompt_token_ids': None, 'kv_transfer_params': None}
```
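The `usage` block also lets you estimate per-request generation throughput, provided you time the call yourself. A small sketch; the elapsed time is measured client-side, not reported by the server:

```python
def generation_throughput(usage: dict, elapsed_s: float) -> float:
    """Completion tokens per second for a single timed request."""
    if elapsed_s <= 0:
        raise ValueError("elapsed_s must be positive")
    return usage["completion_tokens"] / elapsed_s


# e.g. with the usage block above ({'completion_tokens': 176, ...})
# and a hypothetical client-measured latency of 16 seconds:
print(generation_throughput({"completion_tokens": 176}, 16.0))  # 11.0
```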
Logs of the vllm server:
```bash
(APIServer pid=109442) INFO: 127.0.0.1:26854 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=109442) INFO 01-20 10:43:38 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 10.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

View File

@@ -2,14 +2,36 @@
## Generative Models
| Model | Support | INT8(W8A8) | AWQ(W4A16) | GPTQ(WNA16) | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Kunlun Graph |
| :------------ | :-----: | :--------: | :--------: | :---------: | :---: | :-------------: | :-------------: | :-----------: | :----------: |
| Qwen2 | | | | | ✅ | | | | |
| Qwen2.5 | | ✅ | | | | | | | |
| Qwen3 | ✅ | | | | ✅ | | | | |
| Qwen3-Moe | ✅ | ✅ | | | | | | | |
| Qwen3-Next | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| MiMo-V2-Flash | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| Llama2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Llama3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Llama3.1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| gpt-oss | ✅ | ✅ | ✅ | ✅ | | ✅ | | | |
| GLM4.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM4.5Air | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM4.7 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| Kimi-K2 | ✅ | - | ✅ | - | | ✅ | | ✅ | ✅ |
| DeepSeek-R1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| DeepSeek-V3 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| DeepSeek-V3.2 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
## Multimodal Language Models
| Model | Support | INT8(W8A8) | AWQ(W4A16) | GPTQ(WNA16) | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Kunlun Graph |
| :------------- | :-----: | :--------: | :--------: | :---------: | :---: | :-------------: | :-------------: | :-----------: | :----------: |
| Qwen2-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-VL-MoE | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| Qwen3-Omni-MoE | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| InternVL-2.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| InternVL-3.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| InternS1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |

View File

@@ -1,5 +1,5 @@
unset XPU_DUMMY_EVENT
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export XPU_USE_FAST_SWIGLU=1  # use the fast swiglu implementation in the MoE op
export XMLIR_CUDNN_ENABLED=1

View File

@@ -56,6 +56,16 @@ def register():
    """Register the Kunlun platform"""
    from .utils import redirect_output
    from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema

    # Change for GLM5: swap in the XPU config registry before vLLM caches it
    if "vllm.transformers_utils.config" in sys.modules:
        from .transformer_utils.config import _XPU_CONFIG_REGISTRY
        sys.modules["vllm.transformers_utils.config"]._CONFIG_REGISTRY = _XPU_CONFIG_REGISTRY

    import vllm.config.model as model_module
    from .config.model import is_deepseek_mla
    model_module.ModelConfig.is_deepseek_mla = property(is_deepseek_mla)

    import_hook()
    return "vllm_kunlun.platforms.kunlun.KunlunPlatform"

View File

View File

@@ -0,0 +1,22 @@
def is_deepseek_mla(self) -> bool:
    if not hasattr(self.hf_text_config, "model_type"):
        return False
    elif self.hf_text_config.model_type in (
        "deepseek_v2",
        "deepseek_v3",
        "deepseek_v32",
        "deepseek_mtp",
        "kimi_k2",
        "longcat_flash",
        "glm_moe_dsa",
    ):
        return self.hf_text_config.kv_lora_rank is not None
    elif self.hf_text_config.model_type == "eagle":
        # if the model is an EAGLE module, check for the
        # underlying architecture
        return (
            self.hf_text_config.model.model_type
            in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
            and self.hf_text_config.kv_lora_rank is not None
        )
    return False

View File

@@ -3,91 +3,113 @@ from vllm import ModelRegistry
def register_model():
    # from .demo_model import DemoModel  # noqa: F401
    from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration  # noqa: F401
    from .qwen2_vl import Qwen2VLForConditionalGeneration  # noqa: F401
    from .qwen3_moe import Qwen3MoeForCausalLM  # noqa: F401
    from .qwen3_omni_moe_thinker import (  # noqa: F401
        Qwen3OmniMoeThinkerForConditionalGeneration,
    )
    from .qwen3_vl import Qwen3VLForConditionalGeneration  # noqa: F401
    from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration  # noqa: F401

    # from .llama4 import Llama4ForCausalLM  # noqa: F401
    # from .mllama4 import Llama4ForConditionalGeneration  # noqa: F401
    # from .deepseek_v2 import KunlunDeepseekV2MoE
    # ModelRegistry.register_model(
    #     "DemoModel",
    #     "vllm_kunlun.model_executor.models.demo_model:DemoModel")
    ModelRegistry.register_model(
        "Qwen2VLForConditionalGeneration",
        "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "Qwen2_5_VLForConditionalGeneration",
        "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"
    )
    ModelRegistry.register_model(
        "Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"
    )
    ModelRegistry.register_model(
        "Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM"
    )
    ModelRegistry.register_model(
        "Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"
    )
    ModelRegistry.register_model(
        "GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"
    )
    ModelRegistry.register_model(
        "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
    )
    ModelRegistry.register_model(
        "InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM"
    )
    ModelRegistry.register_model(
        "InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"
    )
    ModelRegistry.register_model(
        "InternS1ForConditionalGeneration",
        "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "Qwen3VLForConditionalGeneration",
        "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "Qwen3VLMoeForConditionalGeneration",
        "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "Qwen3OmniMoeForConditionalGeneration",
        "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
    )
    ModelRegistry.register_model(
        "SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"
    )
    ModelRegistry.register_model(
        "MiMoV2FlashForCausalLM",
        "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
    )
    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
    )
    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
    )
    ModelRegistry.register_model(
        "DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"
    )
    ModelRegistry.register_model(
        "GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM"
    )


def register_quant_method():
    """to do"""

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm
logger = init_logger(__name__)
KVCache = tuple[torch.Tensor, torch.Tensor]
@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fc",
)
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(self.num_mtp_layers)
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
hidden_states = self.pre_fc_norm_hidden(hidden_states)
hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
hidden_states = self.fc(hidden_states)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
current_step_idx = spec_step_idx % self.num_mtp_layers
hidden_states, residual = self.layers[current_step_idx](
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert (
not cache_config.enable_prefix_caching
), "Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.model = Qwen3NextMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
)
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
hidden_states = self.model(
input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
spec_step_idx: int = 0,
) -> Optional[torch.Tensor]:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
shared_weight_names = ["embed_tokens", "lm_head"]
def remap_weight_names(weights):
for name, weight in weights:
if name.startswith("mtp."):
name = name.replace("mtp.", "model.")
elif not any(key in name for key in shared_weight_names):
continue
yield name, weight
loader = AutoWeightsLoader(self)
return loader.load_weights(remap_weight_names(weights))
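The `remap_weight_names` filter above rewrites the MTP prefix and passes through only the shared weights. A standalone sketch of the same logic, with toy weight names:

```python
SHARED = ["embed_tokens", "lm_head"]

def remap(weights):
    """Rewrite the MTP prefix; otherwise keep only shared weights."""
    for name, w in weights:
        if name.startswith("mtp."):
            name = name.replace("mtp.", "model.", 1)
        elif not any(key in name for key in SHARED):
            continue  # drop non-MTP, non-shared weights
        yield name, w

ckpt = [
    ("mtp.layers.0.mlp.weight", 1),
    ("model.embed_tokens.weight", 2),
    ("model.layers.0.mlp.weight", 3),  # neither MTP-prefixed nor shared: filtered
]
```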

View File

@@ -85,7 +85,7 @@ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings)
from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
import xtorch_ops
import kunlun_ops
from einops import repeat
logger = init_logger(__name__)

View File

@@ -16,33 +16,34 @@
# limitations under the License.
"""kunlun custom op entry"""
import torch_xmlir
from typing import Optional
import torch
import os
from typing import Optional, List, Dict
import vllm.envs as envs
import os
import ctypes
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import xtorch_ops
logger.info(f"Load custom ops library success!")
import cocopod # noqa
import kunlun_ops
logger.info("Load custom ops library success!")
except ImportError as e:
logger.warning("Import error msg: %s", e.msg)
_per_token_smooth_quant = True
def is_per_token_smooth_quant():
""" is per token smooth quant """
"""is per token smooth quant"""
return _per_token_smooth_quant
class KunlunOps:
"""KunlunOps"""
# Attention ops
@staticmethod
def paged_attention_v1(
@@ -67,11 +68,11 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV1 """
alibi_sqrt=False,
):
"""PagedAttentionV1"""
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -81,7 +82,7 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
@staticmethod
@@ -110,11 +111,11 @@ class KunlunOps:
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
alibi_sqrt=False
):
""" PagedAttentionV2 """
alibi_sqrt=False,
):
"""PagedAttentionV2"""
# block_size = value_cache.shape[2]
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=query,
k_cache=key_cache,
v_cache=value_cache,
@@ -124,31 +125,28 @@ class KunlunOps:
is_context=is_context,
is_causal=True,
out=output,
vo_head_dim=128
vo_head_dim=128,
)
# Activation ops
@staticmethod
def silu_and_mul(out: torch.Tensor,
x: torch.Tensor):
""" silu and mul """
xtorch_ops.silu_and_mul(
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
"""silu and mul"""
kunlun_ops.silu_and_mul(
x,
axis=-1,
turn=True,
out=out,
)
)
# Activation ops
@staticmethod
def quick_gelu(out: torch.Tensor,
x: torch.Tensor):
""" quick gelu """
xtorch_ops.quick_gelu(
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
"""quick gelu"""
kunlun_ops.quick_gelu(
x,
out=out,
)
)
# Layernorm
@staticmethod
@@ -159,9 +157,7 @@ class KunlunOps:
epsilon,
):
"""rms_norm"""
xtorch_ops.rmsnorm(
x, weight.to(torch.float32), epsilon, out=out
)
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
@staticmethod
def fused_add_rms_norm(
@@ -172,97 +168,61 @@ class KunlunOps:
):
"""fused_add_rms_norm"""
output = torch.empty_like(x)
xtorch_ops.add_rmsnorm(
kunlun_ops.add_rmsnorm(
x, residual, weight.to(torch.float32), epsilon, out=output
)
fused_input = x + residual
residual.copy_(fused_input, non_blocking=True)
x.copy_(output)
# Rotary embedding
@staticmethod
def rotary_embedding(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
num_tokens = query_x.shape[0]
num_heads = query_x.shape[1] // head_size
num_kv_heads = key_x.shape[1] // head_size
torch.ops._C.rotary_embedding(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
is_neox_style)
query_x = query_x.view(num_tokens, num_heads * head_size)
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
)
return query_x, key_x
# Rotary embedding
@staticmethod
def mrotary_embedding(
positions,
mrope_section,
query,
key,
head_size,
cos_sin_cache,
is_neox_style):
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
):
"""
refactor RotaryEmbedding forward function
"""
query_x = query.contiguous()
key_x = key.contiguous()
query_x_dim = query_x.dim()
assert is_neox_style
xtorch_ops.mrotary_embedding_neox(
positions,
query_x,
key_x,
head_size,
cos_sin_cache,
mrope_section)
kunlun_ops.mrotary_embedding_neox(
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
)
query.data = query_x
key.data = key_x
key.data = key_x
return query, key
@staticmethod
def swap_blocks(
src,
dst,
block_mapping):
""" swap_blocks """
xtorch_ops.swap_blocks(
src,
dst,
block_mapping
)
def swap_blocks(src, dst, block_mapping):
"""swap_blocks"""
kunlun_ops.swap_blocks(src, dst, block_mapping)
@staticmethod
def copy_blocks(
key_caches,
value_caches,
block_mapping):
""" copy_blocks """
def copy_blocks(key_caches, value_caches, block_mapping):
"""copy_blocks"""
for i in range(len(key_caches)):
key_caches[i] = key_caches[i].contiguous()
value_caches[i] = value_caches[i].contiguous()
xtorch_ops.copy_blocks(
kunlun_ops.copy_blocks(
key_caches,
value_caches,
block_mapping,
@@ -276,16 +236,10 @@ class KunlunOps:
value_cache,
slot_mapping,
kv_cache_dtype,
):
""" reshape_and_cache """
):
"""reshape_and_cache"""
# slot_mapping_cast = slot_mapping.to(torch.int32)
xtorch_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping
)
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
@staticmethod
def multi_query_kv_attention(
@@ -294,7 +248,7 @@ class KunlunOps:
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
**kargs
**kargs,
) -> torch.Tensor:
"""
query: shape = [num_prompt_tokens, num_heads, head_size]
@@ -304,18 +258,14 @@ class KunlunOps:
key = key.unsqueeze(0)
value = value.unsqueeze(0)
output = torch.empty_like(query)
alibi_slopes = kargs.get("alibi_slopes", None)
mask = kargs.get("mask", None)
is_causal = kargs.get("is_causal", True)
is_lvsl = kargs.get("is_lvsl", True)
B, T, Qh, Hd = query.shape
KVh = key.size(2)
if KVh != Qh:
repeat = Qh // KVh
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
value = value.repeat_interleave(repeat, dim=2)
xtorch_ops.attention(
kunlun_ops.attention(
q=query,
k_cache=key,
v_cache=value,
@@ -328,80 +278,90 @@ class KunlunOps:
return output
@staticmethod
def quant_fusedresidual_rmsnorm_op(x,
residual,
weight,
bias,
scale_to_int,
eps,
dyn_scale: bool,
type: int = 1):
def quant_fusedresidual_rmsnorm_op(
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized fused residual layer normalization"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
out=out, out_scale=out_scale , residual_tensor=residual)
kunlun_ops.quant_fusedresidual_rmsnorm(
x,
residual,
weight,
bias,
eps,
out=out,
out_scale=out_scale,
residual_tensor=residual,
)
if residual is None:
return out, out_scale
return out, out_scale, residual
@staticmethod
def quant_rmsnorm_op(x,
weight,
bias,
scale_to_int,
eps,
dyn_scale : bool,
type: int = 1):
def quant_rmsnorm_op(
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
):
"""Quantized RMSNorm"""
out = torch.empty_like(x, dtype=torch.int8)
if is_per_token_smooth_quant():
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
out_scale = torch.empty(
x.shape[:-1], device=x.device, dtype=torch.float
).unsqueeze(-1)
else:
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
out=out, out_scale=out_scale)
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
return out, out_scale
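`quant_rmsnorm_op` allocates one float scale per token when per-token smooth quant is enabled. A hedged pure-Python sketch of what a per-token int8 scale computation typically looks like (the scale convention is an assumption, not taken from the kernel):

```python
def per_token_int8_quant(rows):
    """Per row: scale = max(|x|) / 127, quantized value = round(x / scale)."""
    quantized, scales = [], []
    for row in rows:
        scale = max(abs(v) for v in row) / 127.0  # one scale per token
        quantized.append([round(v / scale) for v in row])
        scales.append(scale)
    return quantized, scales

q, s = per_token_int8_quant([[2.0, -1.0, 0.5]])
```

Dequantizing with `q * scale` then recovers each value to within half a scale step, which is why the op returns `out_scale` alongside the int8 tensor.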
@staticmethod
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
kunlun_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)
out = out.view(*output_bs_shape, weight_shape[0])
@@ -411,6 +371,7 @@ class KunlunOps:
if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x)
@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
@@ -427,23 +388,24 @@ class KunlunOps:
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
moe_top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
moe_top_k,
dtype=torch.int32,
device=hidden_states.device)
normed_score = torch.empty(
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
num_blocks,
global_num_experts,
dtype=torch.int32,
device=hidden_states.device,
)
router_logits = router_logits.to(torch.float)
if scoring_func == "softmax":
@@ -452,24 +414,27 @@ class KunlunOps:
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
stable=False,
)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
if w1_bias is not None or w2_bias is not None:
if w1_bias is not None or w2_bias is not None:
# Right now this branch is for gpt oss
# TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
out = torch.zeros(
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts):
@@ -477,9 +442,13 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
groupgemm1 = cur_token@ w1[i].T
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
groupgemm1 = cur_token @ w1[i].T
# Add w13 bias
if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i]
@@ -489,53 +458,129 @@ class KunlunOps:
if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2
output = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
output = (
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
.sum(dim=1)
.to(hidden_states.dtype)
)
return output
else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.gen_block_statistic(topk_ids,block_statistic)
# from vllm.forward_context import get_forward_context
# forward_context = get_forward_context()
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
# prefix = "model.layers.0.linear_attn"
# if attn_metadata is not None:
# attn_metadata = attn_metadata[prefix]
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
if M * moe_top_k < 400:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids, global_num_experts, False, False, hidden_states
)
)
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
topk_ids, global_num_experts, False
)
out = torch.ops.xspeedgate_ops.fused_moe(
hidden_states,
w1,
w2,
normed_score.to(hidden_states.dtype),
sorted_tokens_num_lod,
sorted_tokens_idx,
experts_num_lod,
)
return out.sum(1)
y = torch.empty(M,moe_top_k,
w1.shape[1],
if M * moe_top_k > 768:
moe_expand = torch.empty(
(M * moe_top_k, N),
dtype=hidden_states.dtype,
device=hidden_states.device)
device=hidden_states.device,
) # [M*top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1,
dtype=torch.int32,
device=hidden_states.device,
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
else:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids,
global_num_experts,
index_have_neg=False,
sort_mode=True,
x=hidden_states,
)
)
y = torch.empty(
M,
moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
if M < 1024:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
)
d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out1 = out1.reshape(-1, out1.shape[-1])
else:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
act="SWISH_GLU",
)
y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])
out = torch.empty(
M,
moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)
out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)
out1 = out1.reshape(-1, out1.shape[-1])
torch.ops._C.moe_fc(
x=out1,
weight=w2,
@@ -545,8 +590,12 @@ class KunlunOps:
y=out,
)
dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
dequant_scale = torch.ones(
[M, moe_top_k], dtype=torch.float32, device=out.device
)
output = torch.empty(
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)
torch.ops._C.moe_post(
@@ -554,9 +603,9 @@ class KunlunOps:
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
y=output,
)
return output
@staticmethod
@@ -575,23 +624,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape
router_logits = x.to(linear_weights.dtype)@linear_weights.T
topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
router_logits = x.to(linear_weights.dtype) @ linear_weights.T
topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
@@ -605,11 +654,19 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.silu_and_mul(up_gate, cur_token@ w13_weight[i].T)
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
out[selected_token] = up_gate @ w2_weight[i].T
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)
return output
@@ -645,11 +702,12 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
kv_lora_rank,
@@ -686,7 +744,6 @@ class KunlunOps:
)
return output
def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
@@ -695,32 +752,41 @@ class KunlunOps:
threshold: float = 20.0,
) -> torch.Tensor:
"""fused_gdn_gating"""
output = xtorch_ops.fused_gdn_gating(
output = kunlun_ops.fused_gdn_gating(
A_log,
a,
dt_bias,
)
return output
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
'''
Core operator of Gated DeltaNet in the Qwen3-NEXT model; fuses sigmoid_gating and delta_rule_update together
1. Sigmoid Gating: gates the input, similar to a GLU (Gated Linear Unit).
2. Delta Rule Update: performs a parallel state-space model (SSM) recurrent update, combined with a local attention mechanism
'''
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Core operator of Gated DeltaNet in the Qwen3-NEXT model; fuses sigmoid_gating and delta_rule_update together
1. Sigmoid Gating: gates the input, similar to a GLU (Gated Linear Unit)
2. Delta Rule Update: performs a parallel state-space model (SSM) recurrent update, combined with a local attention mechanism.
"""
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q,
k,
v,
g,
beta,
scale,
h0_source,
output_final_state,
use_qk_l2norm_in_kernel,
cu_seqlens,
)
return (o, final_state)
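As a sanity reference for the docstring above, the gated delta rule recurrence can be written out one step at a time. A scalar toy sketch (head dimension 1; this is my reading of the update, not the kernel's exact math):

```python
def gated_delta_step(S, q, k, v, g, beta):
    """One recurrent step: decay the state, measure prediction error, rank-1 update."""
    S = g * S                  # gated decay of the running state
    err = v - S * k            # delta rule: error between value and prediction
    S = S + beta * err * k     # write the correction back into the state
    o = S * q                  # read out with the query
    return S, o

S, o = gated_delta_step(0.0, q=1.0, k=1.0, v=2.0, g=0.5, beta=1.0)
```

In the real op `S` is a matrix state per head and the rank-1 update is an outer product `k vᵀ`, but the per-step structure is the same.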

View File

@@ -93,7 +93,7 @@ class SiluAndMul(CustomOp):
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""forward_cuda"""
import xtorch_ops
import kunlun_ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
@@ -103,7 +103,7 @@ class SiluAndMul(CustomOp):
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
"""forward_kunlun"""
import xtorch_ops
import kunlun_ops
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
@@ -251,14 +251,14 @@ class GeluAndMul(CustomOp):
无。
"""
# from vllm import _custom_ops as ops
import xtorch_ops
import kunlun_ops
# d = x.shape[-1] // 2
# output_shape = (x.shape[:-1] + (d, ))
out = torch.empty_like(x)
if self.approximate == "none":
# ops.gelu_and_mul(out, x)
xtorch_ops.gelu(x, out)
kunlun_ops.gelu(x, out)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out
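`SiluAndMul` above splits the last dimension in half, applies SiLU to one half, and multiplies elementwise by the other. A standalone pure-Python reference (assuming the first half is the gate; the kernel's `turn` flag may swap which half is gated):

```python
import math

def silu(v):
    return v / (1.0 + math.exp(-v))  # x * sigmoid(x)

def silu_and_mul(x):
    """Reference: out[i] = silu(x[i]) * x[d + i] with d = len(x) // 2."""
    d = len(x) // 2
    return [silu(x[i]) * x[d + i] for i in range(d)]

out = silu_and_mul([0.0, 1.0, 2.0, 3.0])
```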

View File

@@ -7,7 +7,7 @@ import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
import kunlun_ops
logger = init_logger(__name__)
@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
is_context = False
vo_head_dim = -1
xtorch_ops.paged_attention(out,
kunlun_ops.paged_attention(out,
q,
k_cache, None,
block_table,
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
"""
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:

View File

@@ -3,7 +3,7 @@
from typing import Optional
import torch
import xtorch_ops
import kunlun_ops
from vllm.platforms import current_platform
@@ -16,7 +16,7 @@ def merge_attn_states(
output_lse: Optional[torch.Tensor] = None,
) -> None:
return xtorch_ops.attention_merge_stage(
return kunlun_ops.attention_merge_stage(
prefix_output,
prefix_lse,
suffix_output,
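`merge_attn_states` combines two partial attention results using their log-sum-exp (LSE) statistics. A standalone scalar sketch of that merge (the standard LSE-weighted combination; the kernel's exact interface and tensor layout differ):

```python
import math

def partial_attn(scores, values):
    """Attention over one chunk: normalized output plus the chunk's log-sum-exp."""
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    z = sum(e)
    out = sum(w * v for w, v in zip(e, values)) / z
    return out, m + math.log(z)  # (chunk output, chunk LSE)

def merge_attn_states(o1, lse1, o2, lse2):
    """Merge two chunk outputs into the full-softmax result."""
    m = max(lse1, lse2)
    a1, a2 = math.exp(lse1 - m), math.exp(lse2 - m)
    return (a1 * o1 + a2 * o2) / (a1 + a2), m + math.log(a1 + a2)

scores, values = [1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]
full, _ = partial_attn(scores, values)
o1, l1 = partial_attn(scores[:2], values[:2])
o2, l2 = partial_attn(scores[2:], values[2:])
merged, _ = merge_attn_states(o1, l1, o2, l2)
```

Because each chunk output is already normalized within the chunk, weighting by `exp(lse)` reproduces the attention over the full sequence exactly.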

View File

@@ -9,60 +9,196 @@
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F
import cocopod # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod
def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = -A.transpose(1,2)
def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)
# kernel version
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
Simplest version: assumes equal-length sequences and the same number of key and value heads
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u
def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices)  # use a set to avoid duplicates
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
# compute the first aligned boundary:
# we want start + n*chunk_size, where n is the smallest integer such that the result exceeds start
first_boundary = start + chunk_size
# insert every aligned boundary within (start, end)
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
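`split_by_value` above inserts chunk-aligned boundaries between consecutive sequence offsets. A list-based standalone sketch of the same logic with a small worked input:

```python
def split_by_value(indices, chunk_size=64):
    """Insert every start + n*chunk_size boundary that falls strictly inside a segment."""
    result = set(indices)
    for start, end in zip(indices, indices[1:]):
        boundary = start + chunk_size
        while boundary < end:
            result.add(boundary)
            boundary += chunk_size
    return sorted(result)

# Segment [0, 100) gains a boundary at 64; [100, 130) is shorter than one chunk.
splits = split_by_value([0, 100, 130])
```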
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
# torch version
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
"""
w, u = recompute_w_u_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
"""
# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if head_first:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False instead.",
            DeprecationWarning,
            stacklevel=2,
        )
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            " Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
            " Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
scale = k.shape[-1] ** -0.5
if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, 'b t h ... -> b h t ...')
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state
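For reference, the recurrence that these chunked kernels accelerate can be sketched per token in plain Python. This is a minimal single-head sketch inferred from the op names (`gated_delta_rule`, log-decay `g`, write strength `beta`); `gated_delta_rule_ref` is an illustrative name, not part of this codebase:

```python
import math

def gated_delta_rule_ref(q, k, v, g, beta, scale):
    # q, k: [T][K]; v: [T][V]; g (log decay), beta: [T]
    # state S: [K][V]; returns per-token outputs o and the final state
    K, V = len(k[0]), len(v[0])
    S = [[0.0] * V for _ in range(K)]
    o = []
    for t in range(len(q)):
        decay = math.exp(g[t])
        S = [[decay * S[i][j] for j in range(V)] for i in range(K)]
        # delta rule: correct the state toward v_t along direction k_t
        v_hat = [sum(S[i][j] * k[t][i] for i in range(K)) for j in range(V)]
        err = [beta[t] * (v[t][j] - v_hat[j]) for j in range(V)]
        S = [[S[i][j] + k[t][i] * err[j] for j in range(V)] for i in range(K)]
        o.append([scale * sum(S[i][j] * q[t][i] for i in range(K)) for j in range(V)])
    return o, S
```

The chunked path above computes the same quantity blockwise (via `recompute_w_u_fwd`, `chunk_gated_delta_rule_fwd_h`, and `chunk_fwd_o`) instead of token by token.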


@@ -12,21 +12,21 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics({
'USE_G': lambda args: args['g'] is not None,
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
# @triton.autotune(
# configs=[
# triton.Config({
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
# ],
# key=['H', 'K', 'V', 'BT'],
# )
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
chunk_indices + i_t * 2 + 1
).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
cu_seqlens + i_n + 1
).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
(BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
(BK, BT), (0, 1))
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
(BK, BV), (1, 0))
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(
if USE_G:
g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * tl.exp(b_g)[:, None]
b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])
@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
# b_A = tl.where(m_A, b_A, 0)
b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
(BT, BV), (1, 0))
p_v = tl.make_block_ptr(
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: Optional[torch.Tensor] = None, # cumsum of log decay
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
chunk_size: int = 64,
) -> torch.Tensor:
_, T, _, _, _ = *q.shape, v.shape[-1]
if FLA_GDN_FIX_BT:
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
def grid(meta):
return (triton.cdiv(V, meta['BV']), NT, B * H)
chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=64,
BV=32
o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
)
return o
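The block-size logic above (`FLA_GDN_FIX_BT` versus `min(chunk_size, max(16, triton.next_power_of_2(T)))`) can be sketched without Triton; `select_block_size` and `next_power_of_2` are illustrative stand-ins:

```python
def next_power_of_2(n):
    # smallest power of two >= n (Triton's convention)
    return 1 if n <= 0 else 1 << (n - 1).bit_length()

def select_block_size(T, chunk_size=64, fix_bt=False):
    # mirrors the BT selection above: a fixed 64, or a power of two
    # clamped to the range [16, chunk_size]
    return 64 if fix_bt else min(chunk_size, max(16, next_power_of_2(T)))
```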


@@ -9,29 +9,29 @@
# ruff: noqa: E501
from typing import Optional
import kunlun_ops
import torch
import xtorch_ops
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False):
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: Optional[torch.LongTensor] = None,
ssm_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
q.contiguous(),
k.contiguous(),
v.contiguous(),
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
h0_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
is_h0_transposed=True
is_h0_transposed=True,
)
return o, final_state
@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            " Please flatten variable-length inputs before processing.")
            " Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:


@@ -10,22 +10,21 @@
import os
from typing import Optional
import kunlun_ops
import torch
from vllm.triton_utils import tl, triton
import xtorch_ops
BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(configs=[
triton.Config({}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16, 32]
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
],
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
tl.store(y + cols, b_y, mask=mask)
@triton.autotune(configs=[
triton.Config({'BT': BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
key=['D'])
@triton.autotune(
configs=[
triton.Config({"BT": BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16]
for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd_triton(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
if output_dtype is None:
y = torch.empty_like(x)
else:
y = torch.empty_like(x, dtype=output_dtype)
assert y.stride(-1) == 1
T, D = x.shape[0], x.shape[-1]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
x,
y,
eps,
T,
D,
MBLOCK,
)
else:
if D <= 512:
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta['BT']), )
l2norm_fwd_kernel[grid](
x,
y,
eps,
NB=NB,
T=T,
D=D,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T, )](
x,
y,
eps=eps,
D=D,
BD=BD,
)
return y.view(x_shape_og)
def l2norm_fwd(x: torch.Tensor,
eps: float = 1e-6,
output_dtype: Optional[torch.dtype] = None):
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
out = torch.empty_like(x)
xtorch_ops.l2norm(x, out, eps)
kunlun_ops.l2norm(x, out, eps)
return out
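A minimal pure-Python reference of the L2 normalization these kernels compute, assuming the usual `x * rsqrt(sum(x^2) + eps)` row-wise convention suggested by the kernel bodies; `l2norm_ref` is an illustrative name:

```python
import math

def l2norm_ref(rows, eps=1e-6):
    # normalize each row to (approximately) unit L2 norm
    out = []
    for row in rows:
        inv = 1.0 / math.sqrt(sum(v * v for v in row) + eps)
        out.append([v * inv for v in row])
    return out
```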


@@ -19,20 +19,21 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from .utils import input_guard
def rms_norm_ref(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True):
def rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True,
):
dtype = x.dtype
weight = weight.float()
bias = bias.float() if bias is not None else None
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
weight)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
eps)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
return out.to(dtype)
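The semantics of `rms_norm_ref` (without grouping or bias) can be restated torch-free, following its docstring contract "norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"; `rms_norm_gated_ref` is an illustrative sketch, not the library function:

```python
import math

def silu(z):
    return z / (1.0 + math.exp(-z))

def rms_norm_gated_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    # norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
    if z is not None and not norm_before_gate:
        x = [xi * silu(zi) for xi, zi in zip(x, z)]
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [xi * rstd * w for xi, w in zip(x, weight)]
    if z is not None and norm_before_gate:
        out = [oi * silu(zi) for oi, zi in zip(out, z)]
    return out
```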
@triton.heuristics({
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
@@ -149,46 +150,50 @@ def layer_norm_fwd(
# weight = weight.reshape(N)
# print("weight",weight.shape)
# print("x",x.shape)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
layer_norm_fwd_kernel[grid](x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps)
layer_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
def forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
# y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
y = torch.empty_like(x)
mean, rstd = None, None
import kunlun_ops
kunlun_ops.rms_norm_gated(
x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
return y.reshape(x_shape_og)
def layernorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, is_rms_norm)
def layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
norm_before_gate, True)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNormGated(nn.Module):
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return layernorm_fn(x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module):
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
torch.nn.init.ones_(self.weight)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return rmsnorm_fn(x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)


@@ -11,7 +11,6 @@
from typing import Optional
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
torch.complex64: 1.3e-6,
}
def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
assert res.dtype == dtype
ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
rtol = RESOLUTION[dtype]
torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
# configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = prepare_chunk_indices(
cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
recompute_w_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
)
return w, u


@@ -15,51 +15,52 @@
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import xtorch_ops
import torch
from vllm.model_executor.layers import layernorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
def vllm_kunlun_forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if x.is_contiguous() == False:
        # Kunlun kernels do not support non-contiguous input,
        # so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""forward_cuda"""
if not x.is_contiguous():
        # Kunlun kernels do not support non-contiguous input,
        # so we must make it contiguous() manually
x = x.contiguous()
if self.variance_size_override is not None:
return self.forward_native(x, residual)
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
if residual is not None:
# residual_output = torch.empty_like(residual)
torch.ops._C.add_rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
residual,
residual_output=residual,
weight=self.weight.data,
eps=self.variance_epsilon,
output=x,
)
return out
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
self.weight.data,
out,
self.variance_epsilon,
)
return out
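The patched forward delegates to the fused `torch.ops._C.add_rmsnorm` / `rmsnorm` custom ops. An unfused sketch of the contract assumed from the call sites (residual is replaced by `x + residual`, then the sum is RMS-normalized and scaled); `add_rmsnorm_ref` is a hypothetical name:

```python
import math

def add_rmsnorm_ref(x, residual, weight, eps=1e-6):
    # contract assumed from the call site: new_residual = x + residual,
    # output = rmsnorm(new_residual) * weight
    s = [xi + ri for xi, ri in zip(x, residual)]
    rstd = 1.0 / math.sqrt(sum(v * v for v in s) / len(s) + eps)
    out = [v * rstd * w for v, w in zip(s, weight)]
    return out, s
```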
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
@staticmethod
def forward_xpu(
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
x: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if x.is_contiguous() == False:
if not x.is_contiguous():
            # Kunlun kernels do not support non-contiguous input,
            # so we must make it contiguous() manually
x = x.contiguous()
if x.dim() == 3:
x_shape = x.shape
x = x.view(-1, x.size(-1))
if residual is not None:
torch.ops._C.add_rmsnorm(
out = torch.empty_like(x)
out_residual = torch.empty_like(residual)
torch.ops._C.gemma_add_rmsnorm(
x,
residual,
residual_output=residual,
weight=weight+1,
residual_output=out_residual,
weight=weight,
eps=variance_epsilon,
output=x
output=out,
)
else:
out = torch.empty_like(x)
torch.ops._C.gemma_rmsnorm(
x,
weight,
out,
variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
torch.ops._C.rmsnorm(
x,
weight+1,
out,
variance_epsilon,
)
return out
if x.dim() == 3:
x = x.view(x_shape)
if out is not None:
out = out.view(x_shape)
if residual is not None:
return out, out_residual
else:
return out
def forward_cuda(
self,
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if torch.compiler.is_compiling():
self.forward_static = self.forward_xpu # only use in cudagraph
self.forward_static = self.forward_xpu # only use in cudagraph
return self.forward_native(x, residual)
if not getattr(self, "_is_compiled", False):
self.forward_static = torch.compile( # type: ignore
self.forward_static, backend="aot_eager")
self.forward_static, backend="aot_eager"
)
self._is_compiled = True
return self.forward_native(x, residual)
RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large


@@ -113,7 +113,7 @@ class KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase):
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# NOTE: xtorch_ops use max as scale
# NOTE: kunlun_ops use max as scale
with torch.no_grad():
layer.w13_weight_scale.mul_(127.0)
layer.w2_weight_scale.mul_(127.0)


@@ -8,9 +8,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
# fix bfloat16 double size issue
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
logger = init_logger(__name__)
class KunlunPlatform(Platform):


@@ -0,0 +1,21 @@
#
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#
from . import tokenizer
__all__ = ["tokenizer"]


@@ -0,0 +1,27 @@
from transformers import PretrainedConfig
from vllm.transformers_utils.config import LazyConfigDict, _CONFIG_REGISTRY
_XPU_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
glm_moe_dsa="DeepseekV3Config",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct)
jais="JAISConfig",
mlp_speculator="MLPSpeculatorConfig",
medusa="MedusaConfig",
midashenglm="MiDashengLMConfig",
eagle="EAGLEConfig",
speculators="SpeculatorsConfig",
nemotron="NemotronConfig",
olmo3="Olmo3Config",
ovis="OvisConfig",
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
step3_text="Step3TextConfig",
qwen3_next="Qwen3NextConfig",
)
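The registry above maps `model_type` strings to config classes via `LazyConfigDict`. The lazy-resolution idea can be illustrated with a minimal stand-in (this `LazyDict` is a sketch, not the vLLM implementation):

```python
class LazyDict(dict):
    """Resolve factory callables on first access and cache the result."""

    def __init__(self, factories):
        super().__init__()
        self._factories = dict(factories)

    def __missing__(self, key):
        # called only when the key has not been resolved yet
        value = self._factories[key]()
        self[key] = value
        return value

registry = LazyDict({"qwen3_next": lambda: "Qwen3NextConfig"})
```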


@@ -0,0 +1,223 @@
#
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import tempfile
import shutil
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils import tokenizer
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
TokenizerBase = Any
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase]
def kunlun_get_tokenizer(
tokenizer_name: Union[str, Path],
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
revision: Optional[str] = None,
download_dir: Optional[str] = None,
**kwargs,
) -> AnyTokenizer:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# avoid circular import
from vllm.model_executor.model_loader.weight_utils import get_lock
# Only set the tokenizer here, model will be downloaded on the workers.
if not os.path.exists(tokenizer_name):
# Use file lock to prevent multiple processes from
# downloading the same file at the same time.
with get_lock(tokenizer_name, download_dir):
tokenizer_path = snapshot_download(
model_id=tokenizer_name,
cache_dir=download_dir,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
# Ignore weights - we only need the tokenizer.
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
)
tokenizer_name = tokenizer_path
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models
is_gguf = check_gguf_file(tokenizer_name)
if is_gguf:
kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent
# if tokenizer is from official mistral org
is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
if is_from_mistral_org and tokenizer_mode != "mistral":
warnings.warn(
"It is strongly recommended to run mistral models with "
'`--tokenizer-mode "mistral"` to ensure correct '
"encoding and decoding.",
FutureWarning,
stacklevel=2,
)
tokenizer: AnyTokenizer
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision
)
elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name),
*args,
revision=revision,
download_dir=download_dir,
**kwargs,
)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported,
# suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
# FIXME: Temporary compatibility code for new config format. Remove after vLLM upgrade.
if "TokenizersBackend" in str(e):
logger.warning(
"TokenizersBackend not supported, patching tokenizer_config.json "
"and loading with PreTrainedTokenizerFast."
)
tmp_dir = tempfile.mkdtemp(prefix="vllm_tokenizer_patch_")
try:
TOKENIZER_FILES = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"added_tokens.json",
"chat_template.jinja",
"generation_config.json",
]
for fname in TOKENIZER_FILES:
src = os.path.join(tokenizer_name, fname)
if os.path.exists(src):
shutil.copy(src, tmp_dir)
config_path = os.path.join(tmp_dir, "tokenizer_config.json")
with open(config_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
if cfg.get("tokenizer_class") in ("TokenizersBackend",):
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
if "extra_special_tokens" in cfg:
cfg["additional_special_tokens"] = cfg.pop(
"extra_special_tokens"
)
with open(config_path, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=2)
tokenizer = AutoTokenizer.from_pretrained(
tmp_dir,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
else:
raise e
# The special_tokens in tokenizer should also be
# controlled by do_lower_case in encoder_config
encoder_config = get_sentence_transformer_tokenizer_config(
tokenizer_name, revision
)
if isinstance(encoder_config, dict) and encoder_config.get(
"do_lower_case", False
):
special_tokens_map = {
k: v.lower() for k, v in tokenizer.special_tokens_map.items()
}
tokenizer.add_special_tokens(special_tokens_map)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
)
tokenizer = get_cached_tokenizer(tokenizer)
return tokenizer
tokenizer.get_tokenizer = kunlun_get_tokenizer
logger.info_once(
"[Monkey Patch Applied] >>> vllm.transformers_utils.tokenizer.get_tokenizer "
"--> vllm_kunlun.transformer_utils.tokenizer.kunlun_get_tokenizer"
)
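The `tokenizer_config.json` compatibility shim above can be illustrated in isolation. The following is a minimal sketch of just the rewrite step (the helper name `patch_tokenizer_config` is ours, not part of the module):

```python
def patch_tokenizer_config(cfg: dict) -> dict:
    """Rewrite a tokenizer_config.json dict so PreTrainedTokenizerFast
    can load it, mirroring the compatibility shim above."""
    patched = dict(cfg)
    # Newer configs may declare the unsupported TokenizersBackend class.
    if patched.get("tokenizer_class") == "TokenizersBackend":
        patched["tokenizer_class"] = "PreTrainedTokenizerFast"
    # Rename the newer key to the one older tokenizer loaders understand.
    if "extra_special_tokens" in patched:
        patched["additional_special_tokens"] = patched.pop("extra_special_tokens")
    return patched

print(patch_tokenizer_config(
    {"tokenizer_class": "TokenizersBackend", "extra_special_tokens": ["<tool>"]}
))
```

In the real code path this rewrite is applied to a temporary copy of the tokenizer directory, so the original model files are never modified.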

View File

@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends import gdn_attn
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
class GDNAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
return GDNAttentionMetadataBuilder
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: Optional[torch.Tensor] = None
has_initial_state_cpu: Optional[torch.Tensor] = None
spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [num_spec_decodes + 1,]
)
non_spec_query_start_loc: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes + 1,]
)
spec_state_indices_tensor: Optional[torch.Tensor] = None # shape: [batch, num_spec]
non_spec_state_indices_tensor: Optional[torch.Tensor] = (
None # shape: [batch - num_spec_decodes,]
)
non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
spec_sequence_masks: Optional[torch.Tensor] = None # shape: [batch,]
spec_token_masks: Optional[torch.Tensor] = (
None # shape: [num_prefill_tokens + num_decode_tokens,]
)
num_accepted_tokens: Optional[torch.Tensor] = None # shape: [batch,]
# The following attributes are for triton implementation of causal_conv1d
nums_dict: Optional[dict] = None
batch_ptr: Optional[torch.Tensor] = None
token_chunk_offset_ptr: Optional[torch.Tensor] = None
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens # noqa: E501
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_masks = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.bool,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build( # type: ignore[override]
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
(
min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
)
),
dtype=torch.bool,
device=query_start_loc.device,
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
has_initial_state_cpu = has_initial_state.cpu()
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
)
else:
has_initial_state = None
has_initial_state_cpu = None
num_actual_tokens = (
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
)
# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may roll back tokens,
# causing some sequences to have fewer draft tokens than self.num_spec.
#
# In those cases, the max possible batch size for n tokens is
# min(n, cudagraph_max_bs).
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert spec_token_masks is not None
self.spec_token_masks[: spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True
)
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
spec_token_masks[spec_token_masks.size(0) :].fill_(False)
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[
-1
] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
if num_accepted_tokens is not None:
num_accepted_tokens = num_accepted_tokens.to(torch.int32)
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens,
has_initial_state=has_initial_state,
has_initial_state_cpu=has_initial_state_cpu,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
non_spec_state_indices_tensor_cpu=(
non_spec_state_indices_tensor.cpu()
if non_spec_state_indices_tensor is not None
else None
),
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs
), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder
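The spec/non-spec split of `query_start_loc` performed in `build` amounts to two independent cumulative sums over the per-request query lengths, one for speculative-decode requests and one for the rest. A dependency-free sketch of that bookkeeping (the function name is ours):

```python
def split_query_start_loc(query_lens, spec_mask):
    """Split per-request query lengths into two cumulative-offset arrays:
    one for speculative-decode requests, one for all other requests.
    Both arrays start at 0, matching torch.cumsum into a zeroed buffer."""
    spec_loc, non_spec_loc = [0], [0]
    for qlen, is_spec in zip(query_lens, spec_mask):
        if is_spec:
            spec_loc.append(spec_loc[-1] + qlen)
        else:
            non_spec_loc.append(non_spec_loc[-1] + qlen)
    return spec_loc, non_spec_loc

print(split_query_start_loc([3, 1, 2, 1], [True, False, True, False]))
```

The builder does the same thing on-device with boolean-mask indexing and `torch.cumsum`, writing into preallocated buffers so the result can be reused under CUDA graph capture.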

View File

@@ -28,9 +28,9 @@ from typing import (
TypeVar,
)
import kunlun_ops
import numpy as np
import torch
import xtorch_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
AttentionType,
)
from vllm.config import VllmConfig
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
CommonAttentionMetadata,
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
def __post_init__(self):
"""__post_init__"""
self.attn_bias: Optional[List[AttentionBias]] = None
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
self.cross_attn_bias: Optional[List[AttentionBias]] = None
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
@property
def is_all_encoder_attn_metadata_set(self):
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
"""build"""
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
common_prefix_len = common_prefix_len
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
self.device, non_blocking=True
@@ -770,28 +770,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
value = value.contiguous()
if key_cache.is_contiguous():
xtorch_ops.reshape_and_cache(
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping,
)
else:
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
xtorch_ops.reshape_and_cache_flash(
key,
value,
cast_key_cache,
cast_value_cache,
updated_slot_mapping,
)
kunlun_ops.reshape_and_cache_flash(
key[: attn_metadata.num_actual_tokens],
value[: attn_metadata.num_actual_tokens],
key_cache,
value_cache,
updated_slot_mapping,
BLHD_LAYOUT=False,
)
assert attn_type == AttentionType.DECODER
# Decoder self-attention supports chunked prefill.
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
# Only enforce this shape-constraint for decoder
# self-attention
@@ -811,7 +800,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
# Prefix cache
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
xtorch_ops.prefill_attention(
kunlun_ops.prefill_attention(
q=prefill_query,
k=key_cache, # Key Cache [block_num, head, block_size, dim]
v=value_cache,
@@ -827,7 +816,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
softmax_lse=None,
)
else:
xtorch_ops.prefill_attention(
kunlun_ops.prefill_attention(
q=prefill_query,
k=prefill_key,
v=prefill_value,
@@ -860,9 +849,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
decode_meta.block_tables * 2
) # only test in Qwen3-Next
sig = inspect.signature(xtorch_ops.speculative_attention)
sig = inspect.signature(kunlun_ops.speculative_attention)
if "max_window_size" in sig.parameters:
xtorch_ops.speculative_attention(
kunlun_ops.speculative_attention(
out=output[:num_decode_tokens],
# Only MLA support q len > 1 right now
q=decode_query.unsqueeze(0),
@@ -890,7 +879,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
),
)
elif not attn_metadata.is_speculative:
xtorch_ops.paged_attention(
kunlun_ops.paged_attention(
x=decode_query,
k_cache=key_cache,
v_cache=value_cache,
@@ -910,7 +899,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
out = output[:num_decode_tokens]
assert out.is_contiguous()
xtorch_ops.speculative_attention(
kunlun_ops.speculative_attention(
out=out.view(batch_size, qlen, head_num, self.head_size),
q=decode_query.view(batch_size, qlen, head_num, head_dim),
k_cache=key_cache,

View File

@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
infer_global_hyperparameters,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
import xtorch_ops
import kunlun_ops
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
) * q_len
sorted_tokens_idx = torch.arange(
self.num_heads * q_len, dtype=torch.int, device="cuda")
xtorch_ops.mla_bmm_I8(
kunlun_ops.mla_bmm_I8(
x.contiguous(), # [1, 16, 512] torch.float16
self.W_UV, # [16, 128, 512] torch.int8
self.W_UV_SCALE, # [2048, 1] torch.float32
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
tp_q_head_num = q.size(1)
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
softmax_lse.fill_(float('-inf'))
xtorch_ops.attention(
kunlun_ops.attention(
q=q,
k_cache=k,
v_cache=maybe_padded_v,
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
dtype=torch.float, device=kv_b_proj_weight.device)
xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
self.W_UV = W_UV.contiguous()
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
else:
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
xtorch_ops.concat_and_cache_mla(
kunlun_ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
attn_metadata.slot_mapping.flatten(),
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
sorted_tokens_idx = torch.arange(
self.num_heads * q_len, dtype=torch.int, device="cuda")
extra_params = {"trans": False}
xtorch_ops.mla_bmm_I8(
kunlun_ops.mla_bmm_I8(
decode_q_nope.contiguous(),
self.W_UK_T,
self.W_UK_SCALE,

View File

@@ -10,7 +10,7 @@ from packaging import version
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
import kunlun_ops
import os
logger = init_logger(__name__)
@@ -200,16 +200,16 @@ def flashinfer_sample(
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None:
# Top-p only.
next_token_ids = xtorch_ops.top_p_sampling_from_probs(
next_token_ids = kunlun_ops.top_p_sampling_from_probs(
probs,top_p=p, deterministic=True)
elif p is None:
# Top-k only.
next_token_ids = xtorch_ops.top_k_sampling_from_probs(
next_token_ids = kunlun_ops.top_k_sampling_from_probs(
probs, top_k=k, deterministic=True)
else:
# Both top-k and top-p.
k = k.to(torch.int32)
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
next_token_ids = kunlun_ops.top_k_top_p_sampling_from_probs(
probs, top_k=k, top_p=p, deterministic=True)
return next_token_ids.view(-1)

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Union
import kunlun_ops
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
"""
Args:
metadata:
Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
@@ -179,25 +178,15 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
)
else:
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
kunlun_ops.rejection_greedy_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
if sampling_metadata.all_greedy:
return output_token_ids
@@ -222,8 +211,9 @@ def rejection_sample(
sampling_metadata,
device,
)
bonus_token_ids = bonus_token_ids.squeeze(1)
rejection_random_sample_pytorch(
kunlun_ops.rejection_random_sample(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
# num_warps=1,
no_draft_probs=draft_probs is None,
)
return output_token_ids
@@ -374,7 +363,7 @@ def generate_uniform_probs(
random values in the range [0, 1).
"""
uniform_probs = torch.rand(
(num_tokens, ),
(num_tokens,),
dtype=torch.float32,
device=device,
)
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
q[i].exponential_(generator=generator)
recovered_token_ids = torch.empty_like(draft_token_ids)
sample_recovered_tokens_pytorch(
kunlun_ops.sample_recovered_tokens(
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
target_probs,
q,
vocab_size,
IS_NGRAM=draft_probs is None,
no_draft_probs=draft_probs is None,
)
return recovered_token_ids
def rejection_greedy_sample_spec_len_1_pytorch(
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
output_token_ids, # [batch_size, 2]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
accept_req_mask = draft_token_ids == target_argmax
output_token_ids[:, 0] = target_argmax
bonus_token_ids = bonus_token_ids.squeeze(1)
output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
output_token_ids[:, 1])
output_token_ids[:, 1] = torch.where(
accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
)
def rejection_greedy_sample_pytorch(
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
output_token_ids, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens, # [batch_size]
draft_token_ids, # [num_tokens]
target_argmax, # [num_tokens]
bonus_token_ids, # [batch_size]
draft_tokens_per_req, # [batch_size], list
max_spec_len,
is_greedy=None, # [batch_size] or None
):
batch_size = output_token_ids.size(0)
num_tokens = draft_token_ids.size(0)
device = output_token_ids.device
draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
device, non_blocking=True)
device, non_blocking=True
)
if is_greedy is None:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
start_indices = cu_num_draft_tokens - draft_tokens_per_req
req_ids = torch.arange(batch_size, device=device)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
token_positions = torch.arange(
num_tokens, device=device) - start_indices[token_req_ids]
token_positions = (
torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
)
# Find the first mismatch position of each request.
mismatch_global = (draft_token_ids != target_argmax)
mismatch_global = draft_token_ids != target_argmax
if max_spec_len == 0:
first_mismatch_pos_per_req = torch.zeros(batch_size,
dtype=torch.long,
device=device)
first_mismatch_pos_per_req = torch.zeros(
batch_size, dtype=torch.long, device=device
)
else:
# [bs, max_spec_len]
pos_matrix = torch.full((batch_size, max_spec_len),
-1,
dtype=torch.long,
device=device)
pos_matrix = torch.full(
(batch_size, max_spec_len), -1, dtype=torch.long, device=device
)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.full((batch_size, max_spec_len),
False,
dtype=torch.bool,
device=device)
mismatch_matrix = torch.full(
(batch_size, max_spec_len), False, dtype=torch.bool, device=device
)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
max_spec_len * 2)
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
no_mismatch_mask]
no_mismatch_mask
]
# Copy matched target tokens into output.
copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1,
device=device).expand(batch_size, -1)
copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
copy_mask = copy_indices < copy_len.unsqueeze(1)
greedy_mask = is_greedy.unsqueeze(1)
final_copy_mask = copy_mask & greedy_mask
global_idx = start_indices.unsqueeze(1) + copy_indices
output_token_ids[final_copy_mask] = target_argmax[
global_idx[final_copy_mask]].to(output_token_ids.dtype)
output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
output_token_ids.dtype
)
# Fill bonus token.
needs_bonus = is_greedy & (first_mismatch_pos_per_req
>= draft_tokens_per_req)
needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
if torch.any(needs_bonus):
bonus_rows = torch.where(needs_bonus)[0]
bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
if IS_NGRAM:
draft_prob = 1.0
else:
-draft_prob = draft_probs[start_idx + pos,
-                         draft_token_id].item()
+draft_prob = draft_probs[start_idx + pos, draft_token_id].item()
-target_prob = target_probs[start_idx + pos,
-                           draft_token_id].item()
+target_prob = target_probs[start_idx + pos, draft_token_id].item()
uniform_prob = uniform_probs[start_idx + pos].item()
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
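The accept test in the loop above is the standard speculative-decoding rejection rule. A host-side sketch with a hypothetical helper name and toy probabilities:

```python
def accept_draft_token(draft_prob: float, target_prob: float, u: float) -> bool:
    # Accept when the target/draft likelihood ratio beats a uniform sample;
    # a zero draft probability always rejects, as in the loop above.
    return draft_prob > 0 and target_prob / draft_prob >= u

print(accept_draft_token(0.5, 0.4, 0.7))  # True  (ratio 0.8 >= 0.7)
print(accept_draft_token(0.5, 0.3, 0.7))  # False (ratio 0.6 <  0.7)
```

Accepting with probability `min(1, target/draft)` keeps the overall output distribution identical to sampling from the target model alone.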
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
else:
draft_p = draft_probs[token_idx].clone()
target_p = target_probs[token_idx].clone()
-prob = torch.maximum(target_p - draft_p,
-                     torch.tensor(0.0, device=target_p.device))
+prob = torch.maximum(
+    target_p - draft_p, torch.tensor(0.0, device=target_p.device)
+)
-q_values = torch.full((vocab_size, ),
-                      float('-inf'),
-                      device=q.device)
+q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
q_values[:vocab_size] = q[req_idx, :vocab_size]
recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(
if IS_NGRAM:
target_probs[token_idx, draft_token_id] = orig_prob
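The `argmax(prob / q_values)` step above is the exponential-races form of the Gumbel-max trick: dividing the residual distribution `max(target - draft, 0)` by i.i.d. exponential noise and taking the argmax draws a sample from that residual. A small sketch, with illustrative shapes and seed:

```python
import torch

torch.manual_seed(0)
vocab_size = 8
target_p = torch.softmax(torch.randn(vocab_size), dim=-1)
draft_p = torch.softmax(torch.randn(vocab_size), dim=-1)

# Residual mass: tokens where the target assigns more probability than the draft.
prob = torch.clamp(target_p - draft_p, min=0.0)
# Exponential noise; argmax(prob / q) is distributed according to prob.
q = torch.empty(vocab_size).exponential_(1.0)
recovered_id = torch.argmax(prob / q).item()
print(recovered_id)
```

Only tokens with positive residual mass can win the argmax, so the recovered token is always one the target model favored over the draft.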


@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
return next_token_ids, valid_sampled_tokens_count
-EagleProposer.propose = propose
+# EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
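The assignments above follow the usual plugin monkeypatch pattern: rebind a class attribute so every instance picks up the override (the original `propose` rebind is now commented out). A generic sketch with hypothetical names:

```python
class Proposer:
    def prepare(self) -> str:
        return "upstream"

def prepare_patched(self) -> str:
    # Plugin-provided replacement, analogous to prepare_next_token_ids_padded.
    return "patched"

# Rebind on the class so existing and future instances all use the patch.
Proposer.prepare = prepare_patched
print(Proposer().prepare())  # patched
```

Patching the class rather than individual instances matters here because vLLM constructs `EagleProposer` objects internally, after the plugin has loaded.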


@@ -386,8 +386,8 @@ def silu_and_mul_quant_xpu(
pass
+import kunlun_ops  # noqa: E402
import torch # noqa: E402
-import xtorch_ops  # noqa: E402
from torch.library import custom_op, impl # noqa: E402
@@ -405,9 +405,9 @@ def add_rmsnorm(
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
-xtorch_ops.add_rmsnorm(
+kunlun_ops.add_rmsnorm(
x,
-y,  # was originally written as `residual`; it is actually y
+y,
residual_output=residual_output,
weight=weight,
eps=eps,
@@ -429,7 +429,7 @@ def add_rmsnorm_cuda(
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
-xtorch_ops.add_rmsnorm(
+kunlun_ops.add_rmsnorm(
x,
y,
residual_output=residual_output,
@@ -451,7 +451,7 @@ def rmsnorm(
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
-xtorch_ops.rmsnorm(
+kunlun_ops.rmsnorm(
x,
weight,
output,
@@ -471,7 +471,7 @@ def rmsnorm_cuda(
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
-xtorch_ops.rmsnorm(
+kunlun_ops.rmsnorm(
x,
weight,
output,
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
add_rmsnorm.register_fake(_fake_add_rmsnorm)
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
def gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
@impl("_C::gemma_add_rmsnorm", "CUDA")
def gemma_add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_add_rmsnorm_cuda wrapper")
kunlun_ops.gemma_add_rmsnorm(
x,
y,
weight=weight,
output=output,
eps=eps,
enable_pdl=enable_pdl,
interweaved=interweaved,
store_output_before_norm=store_output_before_norm,
bias=bias,
smooth=smooth,
residual_output=residual_output,
force_sdnn=force_sdnn,
)
def _fake_gemma_add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
force_sdnn: bool = False,
):
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
@custom_op("_C::gemma_rmsnorm", mutates_args=())
def gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
@impl("_C::gemma_rmsnorm", "CUDA")
def gemma_rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
) -> None:
# print("gemma_rmsnorm_cuda wrapper")
kunlun_ops.gemma_rmsnorm(
x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
)
def _fake_gemma_rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
enable_pdl: bool = False,
interweave: bool = False,
bias: torch.Tensor = None,
force_sdnn: bool = False,
):
# Set shape/dtype, but do not return a value
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
q_emb: torch.Tensor,
@@ -541,7 +680,7 @@ def split_norm_rope_neox(
rotary_dim: int,
emb_batch_size: int = 1,
) -> None:
-xtorch_ops.split_norm_rope_neox(
+kunlun_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
@@ -577,7 +716,7 @@ def split_norm_rope_neox_cuda(
rotary_dim: int,
emb_batch_size: int = 1,
) -> None:
-xtorch_ops.split_norm_rope_neox(
+kunlun_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
@@ -649,7 +788,7 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"):
def silu_and_mul(
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
-xtorch_ops.swiglu(
+kunlun_ops.swiglu(
x=x,
y=out,
)
@@ -659,7 +798,7 @@ def silu_and_mul(
def silu_and_mul_cuda(
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
-xtorch_ops.swiglu(
+kunlun_ops.swiglu(
x=x,
y=out,
)
@@ -736,7 +875,7 @@ def moe_softmax_topk(
axis: int = -1,
turn: bool = True,
) -> None:
-xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
+kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
@impl("_C::moe_softmax_topk", "CUDA")
@@ -748,7 +887,7 @@ def moe_softmax_topk_cuda(
axis: int = -1,
turn: bool = True,
) -> None:
-xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
+kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
def _fake_moe_softmax_topk(
@@ -781,7 +920,7 @@ def moe_ffn_block(
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
-xtorch_ops.moe_ffn_block(
+kunlun_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
@@ -812,7 +951,7 @@ def moe_ffn_block_cuda(
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
-xtorch_ops.moe_ffn_block(
+kunlun_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
@@ -863,7 +1002,7 @@ def moe_ffn_per_token_block(
ep_size: int = 1,
ep_rank: int = 0,
) -> None:
-xtorch_ops.moe_ffn_per_token_block(
+kunlun_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
@@ -897,7 +1036,7 @@ def moe_ffn_per_token_block_cuda(
ep_size: int = 1,
ep_rank: int = 0,
) -> None:
-xtorch_ops.moe_ffn_per_token_block(
+kunlun_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
@@ -948,7 +1087,7 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
-xtorch_ops.rotary_embedding(
+kunlun_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
@@ -967,7 +1106,7 @@ def rotary_embedding_cuda(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
-xtorch_ops.rotary_embedding(
+kunlun_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
@@ -999,7 +1138,7 @@ def gemm_I8_I8_bf16_nt(
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
-xtorch_ops.gemm_I8_I8_bf16_nt(
+kunlun_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
)
@@ -1012,7 +1151,7 @@ def gemm_I8_I8_bf16_nt_cuda(
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
-xtorch_ops.gemm_I8_I8_bf16_nt(
+kunlun_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
)
@@ -1038,7 +1177,7 @@ def moe_softmax_topk_norm(
block_statistic: torch.Tensor,
stable: bool = True,
) -> None:
-xtorch_ops.moe_softmax_topk_norm(
+kunlun_ops.moe_softmax_topk_norm(
x, normed_score, topk_index, block_statistic, stable
)
@@ -1051,7 +1190,7 @@ def moe_softmax_topk_norm_cuda(
block_statistic: torch.Tensor,
stable: bool = True,
) -> None:
-xtorch_ops.moe_softmax_topk_norm(
+kunlun_ops.moe_softmax_topk_norm(
x, normed_score, topk_index, block_statistic, stable
)
@@ -1071,14 +1210,14 @@ moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
@custom_op("_C::gen_block_statistic", mutates_args=())
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
-xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
+kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
@impl("_C::gen_block_statistic", "CUDA")
def gen_block_statistic_cuda(
topk_ids: torch.Tensor, block_statistic: torch.Tensor
) -> None:
-xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
+kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
def fake_gen_block_statistic(
@@ -1101,7 +1240,7 @@ def moe_pre_sorted(
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False,
) -> None:
-xtorch_ops.moe_pre_sorted(
+kunlun_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
@@ -1123,7 +1262,7 @@ def moe_pre_sorted_cuda(
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False,
) -> None:
-xtorch_ops.moe_pre_sorted(
+kunlun_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
@@ -1171,7 +1310,7 @@ def moe_fc(
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True,
) -> None:
-xtorch_ops.moe_fc(
+kunlun_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
@@ -1214,7 +1353,7 @@ def moe_fc_cuda(
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True,
) -> None:
-xtorch_ops.moe_fc(
+kunlun_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
@@ -1270,7 +1409,7 @@ def moe_post(
dequant_scale: torch.Tensor,
y: torch.Tensor,
) -> None:
-xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
+kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
@impl("_C::moe_post", "CUDA")
@@ -1281,7 +1420,7 @@ def moe_post_cuda(
dequant_scale: torch.Tensor,
y: torch.Tensor,
) -> None:
-xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
+kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
def fake_moe_post(
@@ -1308,7 +1447,7 @@ def moe_sigmoid_group_topk_norm(
n_group: int,
topk_group: int,
) -> None:
-xtorch_ops.moe_sigmoid_group_topk_norm(
+kunlun_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
@@ -1331,7 +1470,7 @@ def moe_sigmoid_group_topk_norm_cuda(
n_group: int,
topk_group: int,
) -> None:
-xtorch_ops.moe_sigmoid_group_topk_norm(
+kunlun_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
@@ -1376,7 +1515,7 @@ def awq_dequantize(
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
-xtorch_ops.awq_dequantize(
+kunlun_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
@@ -1402,7 +1541,7 @@ def awq_dequantize_cuda(
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
-xtorch_ops.awq_dequantize(
+kunlun_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
@@ -1447,7 +1586,7 @@ def awq_gemm(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
-xtorch_ops.awq_gemm(
+kunlun_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
@@ -1471,7 +1610,7 @@ def awq_gemm_cuda(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
-xtorch_ops.awq_gemm(
+kunlun_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
@@ -1508,7 +1647,7 @@ def gptq_shuffle(
q_perm: torch.Tensor,
bit: int,
) -> None:
-xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
+kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
@impl("_C::gptq_shuffle", "CUDA")
@@ -1517,7 +1656,7 @@ def gptq_shuffle_cuda(
q_perm: torch.Tensor,
bit: int,
) -> None:
-xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
+kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
def _fake_gptq_shuffle(
@@ -1541,7 +1680,7 @@ def concat_and_cache_mla(
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
) -> None:
-xtorch_ops.concat_and_cache_mla(
+kunlun_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
@@ -1556,7 +1695,7 @@ def concat_and_cache_mla_cuda(
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
) -> None:
-xtorch_ops.concat_and_cache_mla(
+kunlun_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
@@ -1598,7 +1737,7 @@ def scaled_int8_quant(
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
if symmetric:
# NOTE: For quant2d ops, scale represents max.
-xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
+kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
else:
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
x_q, x.contiguous(), scale, azp
@@ -1625,7 +1764,7 @@ def scaled_int8_quant_cuda(
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
if symmetric:
# NOTE: For quant2d ops, scale represents max.
-xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
+kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
else:
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
x_q, x.contiguous(), scale, azp
@@ -1777,7 +1916,7 @@ def matmul(
dtype=out_dtype,
device=x.device,
)
-xtorch_ops.matmul(
+kunlun_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=out,
@@ -1814,7 +1953,7 @@ def matmul_cuda(
dtype=out_dtype,
device=x.device,
)
-xtorch_ops.matmul(
+kunlun_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=out,
@@ -1865,7 +2004,7 @@ def quant2d(
max: torch.Tensor,
force_sdnn: bool = False,
) -> None:
-xtorch_ops.quant2d(
+kunlun_ops.quant2d(
x=x,
y=x_q,
max=max,
@@ -1880,7 +2019,7 @@ def quant2d_cuda(
max: torch.Tensor,
force_sdnn: bool = False,
) -> None:
-xtorch_ops.quant2d(
+kunlun_ops.quant2d(
x=x,
y=x_q,
max=max,
@@ -1954,7 +2093,7 @@ def I8_mqa_logits(
is_causal: Optional[bool] = False,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.I8_mqa_logits(
+kunlun_ops.I8_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
@@ -1984,7 +2123,7 @@ def I8_mqa_logits_cuda(
is_causal: Optional[bool] = False,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.I8_mqa_logits(
+kunlun_ops.I8_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
@@ -2034,7 +2173,7 @@ def I8_paged_mqa_logits(
out: torch.Tensor,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.I8_paged_mqa_logits(
+kunlun_ops.I8_paged_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
@@ -2060,7 +2199,7 @@ def I8_paged_mqa_logits_cuda(
out: torch.Tensor,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.I8_paged_mqa_logits(
+kunlun_ops.I8_paged_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
@@ -2111,7 +2250,7 @@ def sparse_prefill_fwd_opt(
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.sparse_prefill_fwd_opt(
+kunlun_ops.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
@@ -2147,7 +2286,7 @@ def sparse_prefill_fwd_opt_cuda(
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False,
) -> None:
-xtorch_ops.sparse_prefill_fwd_opt(
+kunlun_ops.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
@@ -2207,7 +2346,7 @@ def fwd_kvcache_mla(
use_xfa_boost: Optional[bool] = False,
kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
-xtorch_ops.fwd_kvcache_mla(
+kunlun_ops.fwd_kvcache_mla(
q_c=q_c,
kv_cache=kv_cache,
indices=indices,
@@ -2241,7 +2380,7 @@ def fwd_kvcache_mla_cuda(
use_xfa_boost: Optional[bool] = False,
kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
-xtorch_ops.fwd_kvcache_mla(
+kunlun_ops.fwd_kvcache_mla(
q_c=q_c,
kv_cache=kv_cache,
indices=indices,
@@ -2293,7 +2432,7 @@ def dequant_int4(
int4_signed: bool = True,
use_mode_fast: bool = False,
) -> None:
-xtorch_ops.dequant_int4(
+kunlun_ops.dequant_int4(
x=x,
scale=scale,
zero=zero,
@@ -2315,7 +2454,7 @@ def dequant_int4_cuda(
int4_signed: bool = True,
use_mode_fast: bool = False,
) -> None:
-xtorch_ops.dequant_int4(
+kunlun_ops.dequant_int4(
x=x,
scale=scale,
zero=zero,
@@ -2350,7 +2489,7 @@ def fast_topkv2(
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
-topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
+topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
return topk_indices
@@ -2359,7 +2498,7 @@ def fast_topkv2_cuda(
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
-topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
+topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
return topk_indices
@@ -2798,7 +2937,7 @@ def lora_matmul_inplace(
alpha: float = 1.0,
beta: float = 1.0,
) -> None:
-xtorch_ops.matmul(
+kunlun_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=output_tensor,
@@ -2819,7 +2958,7 @@ def lora_matmul_inplace_cuda(
alpha: float = 1.0,
beta: float = 1.0,
) -> None:
-xtorch_ops.matmul(
+kunlun_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=output_tensor,