Compare: v0.11.0-v0...main (16 commits)

Commits: 34e04c5569, a15754c3ba, 4d8575115a, dc63e81a7f, e4c9b9f988, 171f664a0f, 82544aa0cc, 153093d3b3, d425a0d0e9, b82b6026d6, a470452871, d9ad42a174, 77dbc2ddeb, 76ec220b43, bf9369f733, 744719587e
@@ -1,4 +1,4 @@
-FROM wjie520/vllm_kunlun:base_v0.0.2
+FROM vllm_kunlun:custom_base_v0.0.3

 WORKDIR /workspace
README.md (11 changes)
@@ -11,11 +11,12 @@ One of the key features of this project is efficient memory coordination, enabli

### Build from Dockerfile

1. Get or build the base image (base with customized xpytorch, ops, etc.). Ref: [installation](https://vllm-kunlun.readthedocs.io/en/latest/installation.html).
2. Clone this repository and build:

```bash
docker build -t $build_image -f ./Dockerfile .
```

## Usage

@@ -25,4 +26,4 @@ docker build -t $build_image -f ./Dockerfile .

### Environment Variables

- `VXPU_RESERVED_VRAM_SIZE_GB`: The amount of GPU memory reserved for other miscellaneous allocations. Only needs to be set for `vllm_vxpu_daemon`. Try increasing this variable if you launch multiple LLM services and encounter OOM. Default: `8`.
- `VLLM_VXPU_SHM_NAME`: The name of the shm file. Needs to be set for all containers of the shared vxpu group. Default: `/vllm_kunlun_vxpu_offload_shm`.
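For reference, a minimal Python sketch of how a process might read these variables with their documented defaults (the parsing below is an illustration, not the daemon's actual code):

```python
import os

# Variable names and defaults come from the list above; the lookup
# logic itself is illustrative only.
reserved_gb = int(os.environ.get("VXPU_RESERVED_VRAM_SIZE_GB", "8"))
shm_name = os.environ.get("VLLM_VXPU_SHM_NAME", "/vllm_kunlun_vxpu_offload_shm")
```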
README.md.bak (337 changes)
@@ -1,212 +1,199 @@
![vLLM-Kunlun](./docs/source/_static/image/vllm-kunlun.png)

<p align="center">
  <a href="https://vllm-kunlun.readthedocs.io/en/latest/"><b>📖 Documentation</b></a> |
  <a href="https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html"><b>🚀 Quick Start</b></a> |
  <a href="https://vllm-kunlun.readthedocs.io/en/latest/installation.html"><b>📦 Installation</b></a> |
  <a href="https://join.slack.com/t/vllm-kunlun/shared_invite/zt-3iinb8u5z-FcqZKbNNdMJ_32fHmipzvw"><b>💬 Slack</b></a>
</p>

<p align="center">
  <img alt="GitHub License" src="https://img.shields.io/github/license/baidu/vLLM-Kunlun">
  <img alt="GitHub Stars" src="https://img.shields.io/github/stars/baidu/vLLM-Kunlun">
  <img alt="GitHub Forks" src="https://img.shields.io/github/forks/baidu/vLLM-Kunlun">
  <img alt="GitHub Issues" src="https://img.shields.io/github/issues/baidu/vLLM-Kunlun">
  <img alt="Python Version" src="https://img.shields.io/badge/python-%3E%3D3.10-blue">
</p>

---

## Latest News 🔥

- [2026/02] 🧠 **GLM model family support** — Added GLM5, GLM-4.7 MTP (Multi-Token Prediction), and GLM-4.7 tool parser with thinking/non-thinking mode toggle
- [2026/02] ⚡ **Performance optimizations** — Fused MoE with small batches, optimized attention metadata building, Multi-LoRA inference achieves 80%+ of non-LoRA performance
- [2026/02] 🔧 **DeepSeek-V3.2 MTP support** — Added MTP (Multi-Token Prediction) for DeepSeek-V3.2, with RoPE and decoding-stage kernel optimizations
- [2026/01] 🔢 **New quantization methods** — Support for compressed-tensors W4A16, AWQ MoE W4A16, and DeepSeek-V3.2 W8A8 quantization
- [2026/01] 🛠️ **CI/CD overhaul** — Added E2E tests, unit test CI, ruff format checks, and modular CI workflow refactoring
- [2025/12] 🎉 **v0.11.0rc1 released** — Added Qwen3-Omni, Qwen3-Next, Seed-OSS support ([Release Notes](https://github.com/baidu/vLLM-Kunlun/releases/tag/v0.11.0rc1))
- [2025/12] 📦 **v0.10.1.1 released** — 5+ multimodal models, AWQ/GPTQ quantization for dense models, Piecewise CUDA Graph, vLLM V1 engine, Flash-Infer Top-K/Top-P sampling with 10-100× speedup ([Release Notes](https://github.com/baidu/vLLM-Kunlun/releases/tag/v0.10.1.1))
- [2025/12] 🌟 Initial release of vLLM Kunlun — Open sourced on Dec 8, 2025

---

## Overview

**vLLM Kunlun** (`vllm-kunlun`) is a community-maintained hardware plugin designed to seamlessly run [vLLM](https://github.com/vllm-project/vllm) on the **Kunlun XPU**. It is the recommended approach for integrating the Kunlun backend within the vLLM community, adhering to the principles outlined in the [RFC Hardware Pluggable](https://github.com/vllm-project/vllm/issues/11162).

This plugin provides a hardware-pluggable interface that decouples the integration of the Kunlun XPU from vLLM. By utilizing vLLM Kunlun, popular open-source models — including Transformer-like, Mixture-of-Experts (MoE), Embedding, and Multi-modal LLMs — can run effortlessly on the Kunlun XPU.

### ✨ Key Features

- **Seamless Plugin Integration** — Works as a standard vLLM platform plugin via Python entry points; no need to modify vLLM source code
- **Broad Model Support** — Supports 15+ mainstream LLMs including Qwen, Llama, DeepSeek, Kimi-K2, and multimodal models
- **Quantization Support** — INT8 and other quantization methods for MoE and dense models
- **LoRA Fine-Tuning** — LoRA adapter support for Qwen series models
- **Piecewise Kunlun Graph** — Hardware-accelerated graph optimization for high-performance inference
- **FlashMLA Attention** — Optimized multi-head latent attention for DeepSeek MLA architectures
- **Tensor Parallelism** — Multi-device parallel inference with distributed execution support
- **OpenAI-Compatible API** — Serve models with the standard OpenAI API interface

---

## Prerequisites

- **Hardware**: Kunlun3 P800
- **OS**: Ubuntu 22.04
- **Software**:
  - Python >= 3.10
  - PyTorch >= 2.5.1
  - vLLM (same version as vllm-kunlun)
  - transformers >= 4.57.0

---

## Supported Models

### Generative Models

| Model | Support | Quantization | LoRA | Kunlun Graph |
|:--------------|:-------:|:------------:|:----:|:------------:|
| Qwen2         | ✅ | ✅ | ✅ | ✅ |
| Qwen2.5       | ✅ | ✅ | ✅ | ✅ |
| Qwen3         | ✅ | ✅ | ✅ | ✅ |
| Qwen3-Moe     | ✅ | ✅ |    | ✅ |
| Qwen3-Next    | ✅ | ✅ |    | ✅ |
| MiMo-V2-Flash | ✅ | ✅ |    | ✅ |
| Llama2        | ✅ | ✅ | ✅ | ✅ |
| Llama3        | ✅ | ✅ | ✅ | ✅ |
| Llama3.1      | ✅ | ✅ |    | ✅ |
| gpt-oss       | ✅ | ✅ |    |    |
| GLM4.5        | ✅ | ✅ |    | ✅ |
| GLM4.5Air     | ✅ | ✅ |    | ✅ |
| GLM4.7        | ✅ | ✅ |    | ✅ |
| GLM5          | ✅ | ✅ |    | ✅ |
| Kimi-K2       | ✅ | ✅ |    | ✅ |
| DeepSeek-R1   | ✅ | ✅ |    | ✅ |
| DeepSeek-V3   | ✅ | ✅ |    | ✅ |
| DeepSeek-V3.2 | ✅ | ✅ |    | ✅ |

### Multimodal Language Models

| Model | Support | Quantization | LoRA | Kunlun Graph |
|:---------------|:-------:|:------------:|:----:|:------------:|
| Qwen2-VL       | ✅ | ✅ |    | ✅ |
| Qwen2.5-VL     | ✅ | ✅ |    | ✅ |
| Qwen3-VL       | ✅ | ✅ |    | ✅ |
| Qwen3-VL-MoE   | ✅ | ✅ |    | ✅ |
| Qwen3-Omni-MoE | ✅ |    |    | ✅ |
| InternVL-2.5   | ✅ |    |    | ✅ |
| InternVL-3.5   | ✅ |    |    | ✅ |
| InternS1       | ✅ |    |    | ✅ |
---

## Performance Visualization 🚀

### High-performance computing at work: how different models perform on the Kunlun3 P800

Current environment: 16-way concurrency, input/output size 2048.

![Performance](./docs/source/_static/image/performance.png)

## Getting Started

### Quick Start

#### Start an OpenAI-Compatible API Server

```bash
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 8356 \
    --model <your-model-path> \
    --gpu-memory-utilization 0.9 \
    --trust-remote-code \
    --max-model-len 32768 \
    --tensor-parallel-size 1 \
    --dtype float16 \
    --max-num-seqs 128 \
    --max-num-batched-tokens 32768 \
    --block-size 128 \
    --distributed-executor-backend mp \
    --served-model-name <your-model-name>
```

#### Send a Request

```bash
curl http://localhost:8356/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "<your-model-name>",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 512
    }'
```
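The same request body can be assembled from Python; the model name is the placeholder used in the curl example:

```python
import json

# "<your-model-name>" is the placeholder from the example above
payload = {
    "model": "<your-model-name>",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 512,
}
body = json.dumps(payload)  # JSON string to POST to /v1/chat/completions
```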

### Version Matrix

Please use the following recommended versions to get started quickly:

| Version | Release Type | Documentation |
|---------|:------------:|:-------------:|
| v0.11.0 | Latest stable version | [Quick Start](https://vllm-kunlun.readthedocs.io/en/latest/quick_start.html) · [Installation](https://vllm-kunlun.readthedocs.io/en/latest/installation.html) |

---

## Architecture

```
vllm-kunlun/
├── vllm_kunlun/       # Core plugin package
│   ├── platforms/     # Kunlun XPU platform implementation
│   ├── models/        # Model implementations (DeepSeek, Qwen, Llama, etc.)
│   ├── ops/           # Custom operators (attention, linear, sampling, etc.)
│   │   ├── attention/ # FlashMLA, paged attention, merge attention states
│   │   ├── fla/       # Flash linear attention operations
│   │   └── sample/    # Sampling operators
│   ├── v1/            # vLLM V1 engine adaptations
│   ├── compilation/   # Torch compile wrapper for Kunlun Graph
│   ├── csrc/          # C++ extensions (custom CUDA-compatible kernels)
│   └── config/        # Model configuration overrides
├── tests/             # Test suite
├── docs/              # Documentation (Sphinx-based, ReadTheDocs hosted)
├── ci/                # CI pipeline configurations
├── setup.py           # Legacy build script (with C++ extensions)
└── pyproject.toml     # Modern Python build configuration (hatchling)
```

---

## Contributing

We welcome contributions from the community! Please read our [Contributing Guide](CONTRIBUTING.md) before submitting a PR.

### PR Classification

Use the following prefixes for PR titles:

- `[Attention]` — Attention mechanism features/optimizations
- `[Core]` — Core vllm-kunlun logic (platform, attention, communicators, model runner)
- `[Kernel]` — Compute kernels and ops
- `[Bugfix]` — Bug fixes
- `[Doc]` — Documentation improvements
- `[Test]` — Tests
- `[CI]` — CI/CD improvements
- `[Misc]` — Other changes

---

## Star History 🔥

@@ -214,10 +201,14 @@ We opened the project at Dec 8, 2025. We love open source and collaboration ❤

[![Star History Chart](https://api.star-history.com/svg?repos=baidu/vLLM-Kunlun&type=date&legend=bottom-right)](https://www.star-history.com/#baidu/vLLM-Kunlun&type=date&legend=bottom-right)

---

## Sponsors 👋

We sincerely appreciate the [**KunLunXin**](https://www.kunlunxin.com/) team for their support in providing XPU resources, which enabled efficient model adaptation debugging, comprehensive end-to-end testing, and broader model compatibility.

---

## License

Apache License 2.0, as found in the [LICENSE](./LICENSE) file.

collect_env.py (new file, 695 lines)
@@ -0,0 +1,695 @@
# SPDX-License-Identifier: Apache-2.0
# vLLM-Kunlun Environment Information Collection Tool (Fixed Version)
"""
Environment information collection script for Kunlun XPU

Fixed the following issues:
1. Device name displayed as "GPU" → Now correctly shows "P800 OAM"
2. XRE version command error → Now parsed from xpu-smi output
3. vLLM-Kunlun version hardcoded → Now fetched from pip package metadata
"""

import os
import re
import subprocess
import sys
from collections import namedtuple

# =============================================================================
# Part 1: Basic Utility Functions
# =============================================================================


def run(command):
    """
    Execute a shell command and return its result.

    Args:
        command: Command as string or list

    Returns:
        tuple: (return_code, stdout, stderr)
    """
    shell = isinstance(command, str)
    try:
        p = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,  # Capture standard output
            stderr=subprocess.PIPE,  # Capture error output
            shell=shell,
        )
        raw_output, raw_err = p.communicate()
        rc = p.returncode
        # Decode byte stream to string
        output = raw_output.decode("utf-8").strip()
        err = raw_err.decode("utf-8").strip()
        return rc, output, err
    except FileNotFoundError:
        return 127, "", "Command not found"
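As a quick illustration, here is the helper's behaviour on a trivial shell command, restated so the snippet runs standalone (it assumes a POSIX shell with `echo`):

```python
import subprocess


def run(command):
    # Same logic as the helper above, re-stated so this snippet is self-contained
    shell = isinstance(command, str)
    try:
        p = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
        )
        raw_out, raw_err = p.communicate()
        return (
            p.returncode,
            raw_out.decode("utf-8").strip(),
            raw_err.decode("utf-8").strip(),
        )
    except FileNotFoundError:
        return 127, "", "Command not found"


rc, out, err = run("echo hello")
```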


def run_and_read_all(run_lambda, command):
    """Execute command; return its output if successful, None otherwise."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    return out


def run_and_parse_first_match(run_lambda, command, regex):
    """Execute command and extract the first regex match."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    match = re.search(regex, out)
    if match is None:
        return None
    return match.group(1)


# Check if PyTorch is available
try:
    import torch

    TORCH_AVAILABLE = True
except (ImportError, NameError, AttributeError, OSError):
    TORCH_AVAILABLE = False


# =============================================================================
# Part 2: General System Information Collection (Reusing vLLM Original Logic)
# =============================================================================


def get_platform():
    """Get the operating system platform."""
    if sys.platform.startswith("linux"):
        return "linux"
    elif sys.platform.startswith("win32"):
        return "win32"
    elif sys.platform.startswith("darwin"):
        return "darwin"
    return sys.platform


def get_os(run_lambda):
    """Get detailed operating system information."""
    from platform import machine

    if get_platform() == "linux":
        # Try reading /etc/*-release
        rc, out, _ = run_lambda(
            "cat /etc/*-release 2>/dev/null | grep PRETTY_NAME | head -1"
        )
        if rc == 0 and out:
            match = re.search(r'PRETTY_NAME="(.*)"', out)
            if match:
                return f"{match.group(1)} ({machine()})"
        # Fallback: use lsb_release
        rc, out, _ = run_lambda("lsb_release -d 2>/dev/null")
        if rc == 0 and out:
            match = re.search(r"Description:\s*(.*)", out)
            if match:
                return f"{match.group(1)} ({machine()})"
    return f"{get_platform()} ({machine()})"
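The PRETTY_NAME extraction above can be exercised on a sample /etc/os-release line (the value below is illustrative, not taken from any particular system):

```python
import re

# Illustrative /etc/os-release line
sample = 'PRETTY_NAME="Ubuntu 22.04.3 LTS"'
match = re.search(r'PRETTY_NAME="(.*)"', sample)
pretty = match.group(1) if match else None
```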


def get_gcc_version(run_lambda):
    """Get GCC version."""
    return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")


def get_clang_version(run_lambda):
    """Get Clang version."""
    return run_and_parse_first_match(
        run_lambda, "clang --version", r"clang version (.*)"
    )


def get_cmake_version(run_lambda):
    """Get CMake version."""
    return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")


def get_libc_version():
    """Get libc version."""
    import platform

    if get_platform() != "linux":
        return "N/A"
    return "-".join(platform.libc_ver())


def get_python_platform():
    """Get Python platform information."""
    import platform

    return platform.platform()


def get_cpu_info(run_lambda):
    """Get CPU information."""
    if get_platform() == "linux":
        rc, out, err = run_lambda("lscpu")
        return out if rc == 0 else err
    return "N/A"


def get_pip_packages(run_lambda, patterns=None):
    """Get the pip package list, filtered to relevant packages."""
    if patterns is None:
        patterns = {
            "torch",
            "numpy",
            "triton",
            "transformers",
            "vllm",
            "kunlun",
            "xpu",
            "bkcl",
            "xmlir",
        }

    cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
    out = run_and_read_all(run_lambda, cmd)
    if out is None:
        return "pip3", ""

    filtered = "\n".join(
        line
        for line in out.splitlines()
        if any(name.lower() in line.lower() for name in patterns)
    )
    pip_version = "pip3" if sys.version[0] == "3" else "pip"
    return pip_version, filtered
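The substring filter can be checked against a hypothetical freeze listing (the package lines below are made up for the demonstration):

```python
# Hypothetical `pip list --format=freeze` output
out = "torch==2.5.1\nnumpy==1.26.4\nrequests==2.31.0\nvllm==0.11.0"
patterns = {"torch", "numpy", "vllm"}
# Same case-insensitive substring match as the function above
filtered = "\n".join(
    line
    for line in out.splitlines()
    if any(name.lower() in line.lower() for name in patterns)
)
```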


def get_conda_packages(run_lambda, patterns=None):
    """Get the conda package list, filtered to relevant packages."""
    if patterns is None:
        patterns = {
            "torch",
            "numpy",
            "triton",
            "transformers",
            "kunlun",
            "xpu",
            "bkcl",
            "xmlir",
        }

    conda = os.environ.get("CONDA_EXE", "conda")
    out = run_and_read_all(run_lambda, [conda, "list"])
    if out is None:
        return None

    return "\n".join(
        line
        for line in out.splitlines()
        if not line.startswith("#")
        and any(name.lower() in line.lower() for name in patterns)
    )


# =============================================================================
# Part 3: Kunlun-Specific Information Collection (Core Fix)
# =============================================================================


def parse_xpu_smi_output(run_lambda):
    """
    Parse the complete output of the xpu-smi command.

    The xpu-smi output format is similar to nvidia-smi, so it is parsed
    with regular expressions. Example output format:

    +-----------------------------------------------------------------------------+
    | XPU-SMI  Driver Version: 515.58  XPU-RT Version: N/A                        |
    |-------------------------------+----------------------+----------------------+
    | 0  P800 OAM          N/A      | 00000000:52:00.0 N/A |                    0 |
    | N/A  43C  N/A  85W / 400W     | 4MiB / 98304MiB      | 0%           Default |

    Returns:
        dict: Dictionary containing parsing results, or None on failure
    """
    rc, output, _ = run_lambda("xpu-smi")
    if rc != 0 or not output:
        return None

    result = {
        "raw_output": output,
        "driver_version": None,
        "xre_version": None,
        "devices": [],
    }

    # Parse header: Driver Version and XPU-RT Version
    # Format: | XPU-SMI  Driver Version: 515.58  XPU-RT Version: N/A |
    header_match = re.search(
        r"Driver Version:\s*(\S+)\s+XPU-RT Version:\s*(\S+)", output
    )
    if header_match:
        result["driver_version"] = header_match.group(1)
        xre = header_match.group(2)
        result["xre_version"] = xre if xre != "N/A" else None

    # Parse device information
    # Device line:  | 0  P800 OAM  N/A | 00000000:52:00.0 N/A |
    # Next line:    | N/A  43C  N/A  85W / 400W | 4MiB / 98304MiB |
    device_pattern = re.compile(
        r"\|\s*(\d+)\s+(\S+(?:\s+\S+)?)\s+(?:N/A|On|Off)\s*\|"  # ID and Name
        r"\s*([0-9a-fA-F:\.]+)\s*"  # Bus-Id
    )

    # Memory information on the following line
    memory_pattern = re.compile(
        r"\|\s*N/A\s+\d+C\s+N/A\s+\d+W\s*/\s*\d+W\s*\|"
        r"\s*(\d+)MiB\s*/\s*(\d+)MiB\s*\|"  # Memory-Usage / Total
    )

    lines = output.split("\n")
    i = 0
    while i < len(lines):
        line = lines[i]
        device_match = device_pattern.search(line)
        if device_match:
            device_id = int(device_match.group(1))
            device_name = device_match.group(2).strip()
            bus_id = device_match.group(3)

            # The next line should carry the memory info
            memory_used = 0
            memory_total = 0
            if i + 1 < len(lines):
                mem_match = memory_pattern.search(lines[i + 1])
                if mem_match:
                    memory_used = int(mem_match.group(1))
                    memory_total = int(mem_match.group(2))

            result["devices"].append(
                {
                    "id": device_id,
                    "name": device_name,  # Correctly yields "P800 OAM"
                    "bus_id": bus_id,
                    "memory_used_mib": memory_used,
                    "memory_total_mib": memory_total,
                }
            )
        i += 1

    return result
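The two regexes can be sanity-checked against lines shaped like the docstring example (the sample strings below are synthetic, not captured from a real device):

```python
import re

# Synthetic header line shaped like the docstring example
header = "| XPU-SMI  Driver Version: 515.58  XPU-RT Version: N/A |"
m = re.search(r"Driver Version:\s*(\S+)\s+XPU-RT Version:\s*(\S+)", header)
driver, xre = m.group(1), m.group(2)

# Synthetic memory line shaped like "| N/A 43C N/A 85W / 400W | 4MiB / 98304MiB |"
mem_line = "| N/A  43C  N/A  85W / 400W |  4MiB / 98304MiB |"
mem = re.search(
    r"\|\s*N/A\s+\d+C\s+N/A\s+\d+W\s*/\s*\d+W\s*\|"
    r"\s*(\d+)MiB\s*/\s*(\d+)MiB\s*\|",
    mem_line,
)
used_mib, total_mib = int(mem.group(1)), int(mem.group(2))
```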


def get_kunlun_gpu_info(run_lambda):
    """
    Get Kunlun XPU device information.

    Previously torch.cuda.get_device_properties() was used for the name,
    but it only returns "GPU" (because Kunlun masquerades as CUDA).
    Parsing xpu-smi output yields the correct name, e.g. "P800 OAM".

    Returns:
        str: Device information string, or None
    """
    parsed = parse_xpu_smi_output(run_lambda)

    if parsed and parsed["devices"]:
        # Get the real device name from the xpu-smi parse
        lines = []
        for dev in parsed["devices"]:
            memory_gb = dev["memory_total_mib"] / 1024
            # Correctly displays: XPU 0: P800 OAM (96.0GB)
            lines.append(f"XPU {dev['id']}: {dev['name']} ({memory_gb:.1f}GB)")
        return "\n".join(lines)

    # Fallback: use the PyTorch interface (device name will show as "GPU")
    if TORCH_AVAILABLE:
        try:
            device_count = torch.cuda.device_count()
            lines = []
            for i in range(device_count):
                props = torch.cuda.get_device_properties(i)
                name = props.name if hasattr(props, "name") else "Kunlun XPU"
                memory_gb = (
                    props.total_memory / (1024**3)
                    if hasattr(props, "total_memory")
                    else 0
                )
                lines.append(f"XPU {i}: {name} ({memory_gb:.1f}GB)")
            return "\n".join(lines)
        except Exception as e:
            return f"Error: {e}"

    return None


def get_kunlun_driver_version(run_lambda):
    """
    Get the Kunlun driver version.

    Parsed directly from the xpu-smi output header instead of calling
    a nonexistent command.

    Returns:
        str: Driver version, e.g., "515.58", or None
    """
    parsed = parse_xpu_smi_output(run_lambda)
    if parsed and parsed["driver_version"]:
        return parsed["driver_version"]
    return None


def get_kunlun_xre_version(run_lambda):
    """
    Get the Kunlun XRE (runtime) version.

    `xpu-smi --version` is not a valid invocation; instead, the
    "XPU-RT Version" field is parsed from the standard xpu-smi
    output header.

    Returns:
        str: XRE version, or a placeholder if it shows N/A
    """
    parsed = parse_xpu_smi_output(run_lambda)
    if parsed and parsed["xre_version"]:
        return parsed["xre_version"]
    return "N/A (not installed or not detected)"


def get_kunlun_topo(run_lambda):
    """
    Get Kunlun XPU topology information.

    Returns:
        str: Topology information, or None
    """
    # The `xpu-smi topo -m` command reports the topology
    output = run_and_read_all(run_lambda, "xpu-smi topo -m")
    if output:
        return output

    # Fallback: show the device count
    if TORCH_AVAILABLE:
        try:
            count = torch.cuda.device_count()
            return f"Detected {count} Kunlun XPU device(s)"
        except Exception:
            pass

    return None


def get_bkcl_version(run_lambda):
    """
    Get the BKCL (communication library) version.

    BKCL = Baidu Kunlun Communication Library; similar to NVIDIA's NCCL,
    it is used for multi-card communication.

    Returns:
        str: BKCL version information, or None
    """
    # Method 1: BKCL prints its version when loading, e.g.
    # [WARN][BKCL][globals.cpp:268] xccl version: 6ab4ffb [rdma] ...
    # Try locating the library via the torch_xmlir package
    try:
        import torch_xmlir

        # Find the path to libbkcl.so
        bkcl_path = None
        if hasattr(torch_xmlir, "__file__"):
            base = os.path.dirname(torch_xmlir.__file__)
            candidate = os.path.join(base, "libbkcl.so")
            if os.path.exists(candidate):
                bkcl_path = candidate
        if bkcl_path:
            return f"Found at: {bkcl_path}"
    except ImportError:
        pass

    # Method 2: search via ldconfig
    rc, out, _ = run_lambda("ldconfig -p 2>/dev/null | grep -i bkcl | head -1")
    if rc == 0 and out:
        return out

    return None


def get_vllm_kunlun_version():
    """
    Get the vLLM-Kunlun version.

    A hardcoded version ("0.9.2") used to be read from
    vllm_kunlun.platforms.version even when the pip-installed version
    differed (e.g., "0.1.0"). importlib.metadata is now preferred,
    since it reports the actually installed version.

    Returns:
        str: Version number
    """
    # Method 1 (recommended): importlib.metadata (Python 3.8+)
    try:
        from importlib.metadata import version

        return version("vllm-kunlun")
    except Exception:
        pass

    # Method 2: pkg_resources
    try:
        import pkg_resources

        return pkg_resources.get_distribution("vllm-kunlun").version
    except Exception:
        pass

    # Method 3 (fallback): read from code (may be inaccurate)
    try:
        from vllm_kunlun.platforms.version import get_xvllm_version

        return get_xvllm_version() + " (from code, may be inaccurate)"
    except ImportError:
        pass

    return "N/A"


def get_vllm_version():
    """Get the vLLM main package version."""
    try:
        from importlib.metadata import version

        return version("vllm")
    except Exception:
        pass

    try:
        from vllm import __version__

        return __version__
    except ImportError:
        pass

    return "N/A"
|
||||
|
||||
|
||||
# =============================================================================
# Part 4: Environment Variable Collection
# =============================================================================


def get_kunlun_env_vars():
    """Get Kunlun-related environment variables"""
    env_vars = ""
    kunlun_prefixes = (
        "XPU",
        "KUNLUN",
        "BKCL",
        "XCCL",
        "XRE",
        "TORCH",
        "VLLM",
    )
    secret_terms = ("secret", "token", "api", "access", "password")

    for k, v in sorted(os.environ.items()):
        if any(term in k.lower() for term in secret_terms):
            continue
        if any(k.upper().startswith(prefix) for prefix in kunlun_prefixes):
            env_vars += f"{k}={v}\n"

    return env_vars

# =============================================================================
# Part 5: Define Data Structure and Formatted Output
# =============================================================================

KunlunSystemEnv = namedtuple(
    "KunlunSystemEnv",
    [
        # General system information
        "os",
        "gcc_version",
        "clang_version",
        "cmake_version",
        "libc_version",
        "python_version",
        "python_platform",
        "pip_version",
        "pip_packages",
        "conda_packages",
        "cpu_info",
        # PyTorch information
        "torch_version",
        "is_debug_build",
        # Kunlun-specific information
        "kunlun_xpu_info",
        "kunlun_driver_version",
        "kunlun_xre_version",
        "bkcl_version",
        "kunlun_topo",
        # vLLM related
        "vllm_version",
        "vllm_kunlun_version",
        "env_vars",
    ],
)

def get_kunlun_env_info():
    """Collect all environment information"""
    run_lambda = run
    pip_version, pip_list_output = get_pip_packages(run_lambda)

    # PyTorch information
    if TORCH_AVAILABLE:
        torch_version = torch.__version__
        debug_mode_str = str(torch.version.debug)
    else:
        torch_version = "N/A"
        debug_mode_str = "N/A"

    sys_version = sys.version.replace("\n", " ")

    return KunlunSystemEnv(
        # General system information
        os=get_os(run_lambda),
        gcc_version=get_gcc_version(run_lambda),
        clang_version=get_clang_version(run_lambda),
        cmake_version=get_cmake_version(run_lambda),
        libc_version=get_libc_version(),
        python_version=f"{sys_version} ({sys.maxsize.bit_length() + 1}-bit runtime)",
        python_platform=get_python_platform(),
        pip_version=pip_version,
        pip_packages=pip_list_output,
        conda_packages=get_conda_packages(run_lambda),
        cpu_info=get_cpu_info(run_lambda),
        # PyTorch information
        torch_version=torch_version,
        is_debug_build=debug_mode_str,
        # Kunlun-specific information
        kunlun_xpu_info=get_kunlun_gpu_info(run_lambda),
        kunlun_driver_version=get_kunlun_driver_version(run_lambda),
        kunlun_xre_version=get_kunlun_xre_version(run_lambda),
        bkcl_version=get_bkcl_version(run_lambda),
        kunlun_topo=get_kunlun_topo(run_lambda),
        # vLLM related
        vllm_version=get_vllm_version(),
        vllm_kunlun_version=get_vllm_kunlun_version(),
        env_vars=get_kunlun_env_vars(),
    )

# Output format template
kunlun_env_info_fmt = """
==============================
        System Info
==============================
OS                    : {os}
GCC version           : {gcc_version}
Clang version         : {clang_version}
CMake version         : {cmake_version}
Libc version          : {libc_version}
==============================
       PyTorch Info
==============================
PyTorch version       : {torch_version}
Is debug build        : {is_debug_build}
==============================
    Python Environment
==============================
Python version        : {python_version}
Python platform       : {python_platform}
==============================
    Kunlun / XPU Info
==============================
XPU models and configuration :
{kunlun_xpu_info}
Kunlun driver version : {kunlun_driver_version}
XRE (Runtime) version : {kunlun_xre_version}
BKCL version          : {bkcl_version}
XPU Topology:
{kunlun_topo}
==============================
         CPU Info
==============================
{cpu_info}
==============================
Versions of relevant libraries
==============================
{pip_packages}
{conda_packages}
==============================
     vLLM-Kunlun Info
==============================
vLLM Version          : {vllm_version}
vLLM-Kunlun Version   : {vllm_kunlun_version}
==============================
   Environment Variables
==============================
{env_vars}
""".strip()

def pretty_str(envinfo):
    """Format environment information"""
    mutable_dict = envinfo._asdict()

    # Replace None with "Could not collect"
    for key in mutable_dict:
        if mutable_dict[key] is None:
            mutable_dict[key] = "Could not collect"

    # Handle pip package list
    if mutable_dict["pip_packages"]:
        mutable_dict["pip_packages"] = "\n".join(
            f"[{envinfo.pip_version}] {line}"
            for line in mutable_dict["pip_packages"].split("\n")
            if line
        )
    else:
        mutable_dict["pip_packages"] = "No relevant packages"

    # Handle conda package list
    if mutable_dict["conda_packages"]:
        mutable_dict["conda_packages"] = "\n".join(
            f"[conda] {line}"
            for line in mutable_dict["conda_packages"].split("\n")
            if line
        )
    else:
        mutable_dict["conda_packages"] = ""

    return kunlun_env_info_fmt.format(**mutable_dict)

def get_pretty_kunlun_env_info():
    """Get formatted environment information"""
    return pretty_str(get_kunlun_env_info())


def main():
    """Main entry point"""
    print("Collecting Kunlun XPU environment information...")
    output = get_pretty_kunlun_env_info()
    print(output)


if __name__ == "__main__":
    main()
760 docs/source/developer_guide/developer_guide.md Normal file
@@ -0,0 +1,760 @@
# 📖 vLLM-Kunlun New Model Adaptation Manual

> Based on in-depth analysis of the [baidu/vLLM-Kunlun](https://github.com/baidu/vLLM-Kunlun) and [vllm-project/vllm](https://github.com/vllm-project/vllm) repositories.
>
> Applicable Versions: vLLM v0.15.1+ / vLLM-Kunlun main branch

---

## Table of Contents

- [I. Understanding the Overall Architecture](#i-understanding-the-overall-architecture)
  - [1.1 Plugin System](#11-plugin-system)
  - [1.2 Startup Process](#12-startup-process)
  - [1.3 Import Hook Mechanism](#13-import-hook-mechanism)
  - [1.4 Code Architecture](#14-code-architecture)
- [II. New Model Adaptation Step-by-Step](#ii-new-model-adaptation-step-by-step)
  - [Step 0: Pre-assessment](#step-0-pre-assessment)
  - [Step 1: Implement Model Files](#step-1-implement-model-files)
  - [Step 2: Register the Model](#step-2-register-the-model)
  - [Step 3: Verify Registration](#step-3-verify-registration)
  - [Step 4: Testing](#step-4-testing)
- [III. Adaptation Guide for Special Model Types](#iii-adaptation-guide-for-special-model-types)
  - [3.1 MoE Models](#31-moe-models-eg-qwen3-moe-deepseek-v3)
  - [3.2 MLA Models](#32-mla-multi-latent-attention-models-eg-deepseek-v3)
  - [3.3 Multi-modal Models](#33-multi-modal-models-eg-qwen2-vl-internvl)
  - [3.4 Hybrid Attention Models](#34-hybrid-attention-models-eg-qwen3-next)
- [IV. Quantized Model Adaptation](#iv-quantized-model-adaptation)
  - [4.1 Supported Quantization Methods](#41-supported-quantization-methods)
  - [4.2 Special Handling for Quantization](#42-special-handling-for-quantization)
- [V. Custom Operators](#v-custom-operators-if-new-low-level-ops-are-needed)
- [VI. Common Pitfalls Checklist](#vi-common-pitfalls-checklist)
- [VII. Reference Template Quick Look-up](#vii-reference-template-quick-look-up)
- [VIII. Debugging Tips](#viii-debugging-tips)
- [IX. Environment Variables Cheat Sheet](#ix-environment-variables-cheat-sheet)
- [X. PR Submission Standards](#x-pr-submission-standards)

---

## I. Understanding the Overall Architecture

### 1.1 Plugin System

vLLM-Kunlun uses the **OOT (Out-of-Tree) Plugin** approach to integrate with vLLM, primarily registered via `entry_points` in `setup.py`:

```python
# setup.py
entry_points={
    'vllm.platform_plugins': ["kunlun = vllm_kunlun:register"],  # Platform plugin
    'vllm.general_plugins': [
        "kunlun_model = vllm_kunlun:register_model",        # Model registration
        "kunlun_quant = vllm_kunlun:register_quant_method"  # Quantization method
    ],
    "console_scripts": [
        "vllm_kunlun = vllm_kunlun.entrypoints.main:main"
    ]
}
```

### 1.2 Startup Process

```
vllm startup
├─ 1. Discover platform_plugin → call vllm_kunlun:register()
│   ├─ Register KunlunPlatform (defines Attention Backend, Worker, etc.)
│   ├─ Apply import hook (module redirection)
│   └─ Register custom operators (custom_op)
├─ 2. Discover general_plugin → call vllm_kunlun:register_model()
│   └─ Register all Kunlun-adapted models via ModelRegistry.register_model()
└─ 3. Model loading → match registered model classes against the architectures field in config.json
```

### 1.3 Import Hook Mechanism

vLLM-Kunlun uses a custom import hook to **transparently replace** certain vLLM modules with Kunlun-customized versions:

```python
# vllm_kunlun/__init__.py
def _custom_import(module_name, globals=None, locals=None, fromlist=(), level=0):
    try:
        module_mappings = {
            "vllm.compilation.wrapper": "vllm_kunlun.compilation.wrapper",
            "vllm.v1.worker.utils": "vllm_kunlun.v1.worker.utils",
            "vllm.model_executor.model_loader.bitsandbytes_loader": "vllm_kunlun.models.model_loader.bitsandbytes_loader",
            "vllm.v1.sample.ops.topk_topp_sampler": "vllm_kunlun.v1.sample.ops.topk_topp_sampler",
            "vllm.model_executor.layers.sampler": "vllm_kunlun.ops.sample.sampler",
            "vllm.v1.sample.rejection_sampler": "vllm_kunlun.v1.sample.rejection_sampler",
            "vllm.attention.ops.merge_attn_states": "vllm_kunlun.ops.attention.merge_attn_states",
        }

        if module_name in module_mappings:
            if module_name in sys.modules:
                return sys.modules[module_name]
            target_module = module_mappings[module_name]
            module = importlib.import_module(target_module)
            sys.modules[module_name] = module
            sys.modules[target_module] = module
    except Exception:
        pass

    return OLD_IMPORT_HOOK(module_name, globals=globals, locals=locals, fromlist=fromlist, level=level)
```

> **⚠️ Understanding this mechanism is crucial**: even if you write `from vllm.xxx import YYY` in your model code, what you actually get may be `vllm_kunlun.xxx.YYY`.

### 1.4 Code Architecture

```
vllm_kunlun/
├── __init__.py                      # Plugin entry: register() + import_hook()
├── platforms/kunlun.py              # KunlunPlatform: defines Attention Backend, Worker, etc.
├── models/                          # ⭐ Model implementation directory (where you add files)
│   ├── __init__.py                  # ⭐ Model registration entry
│   ├── deepseek_v2.py               # DeepSeek V2/V3 reference implementation
│   ├── deepseek_mtp.py              # DeepSeek MTP (speculative decoding)
│   ├── qwen3.py                     # Qwen3 reference implementation (dense model)
│   ├── qwen3_moe.py                 # Qwen3 MoE reference implementation
│   ├── qwen3_next.py                # Qwen3-Next (hybrid attention)
│   ├── qwen3_vl.py                  # Qwen3 VL (multi-modal)
│   ├── qwen3_vl_moe.py              # Qwen3 VL MoE (multi-modal + MoE)
│   ├── qwen2_vl.py                  # Qwen2 VL
│   ├── qwen2_5_vl.py                # Qwen2.5 VL
│   ├── internlm2.py                 # InternLM2 reference implementation
│   ├── internvl.py                  # InternVL (multi-modal)
│   ├── interns1.py                  # InternS1
│   ├── seed_oss.py                  # SeedOss
│   ├── gpt_oss.py                   # GptOss
│   └── mimo_v2_flash.py             # MiMo-V2-Flash
├── ops/                             # Kunlun custom operators
│   ├── _kunlun_ops.py               # KunlunOps: paged_attention, rms_norm, silu...
│   ├── _custom_ops.py               # vllm custom_op registration
│   ├── activation.py                # Activation functions like SiluAndMul, GeluAndMul
│   ├── attention/                   # Attention operators
│   │   ├── layer.py                 # Attention layer wrapper
│   │   └── backends/kunlun_attn.py  # KunlunAttentionBackend + KunlunAttentionImpl
│   ├── quantization/                # Quantization related: AWQ, GPTQ, CompressedTensors...
│   ├── vocab_parallel_embedding.py  # Custom embedding
│   └── rotary_embedding.py          # Split_Norm_Rope (QKNorm + RoPE fusion)
├── v1/attention/backends/           # Attention backends for the v1 engine
│   ├── kunlun_attn.py               # Standard attention
│   └── mla/                         # MLA (Multi-Latent Attention) implementation
│       ├── flashmla.py
│       ├── flashmla_sparse.py
│       └── common.py
├── compilation/wrapper.py           # torch.compile wrapper
├── config/                          # Model configuration overrides
│   └── model.py                     # Patch for attributes like is_deepseek_mla
├── distributed/                     # Communication related
│   └── kunlun_communicator.py       # Kunlun device communication
└── csrc/                            # C++ extensions
    └── utils.cpp
```

---

## II. New Model Adaptation Step-by-Step

### Step 0: Pre-assessment

Before starting, determine which scenario your model falls into:

| Scenario | Description | Effort |
|------|------|--------|
| **Case A: vLLM already supports the model** | Only need to replace Attention / Activation with Kunlun versions | ⭐ Minimal |
| **Case B: vLLM does not support it; a new architecture is needed** | Requires a full model class implementation + registration | ⭐⭐⭐ High |
| **Case C: MoE variant of an existing model** | Add an MoE layer on top of the dense version | ⭐⭐ Medium |
| **Case D: Multi-modal model** | Language model + vision encoder + projector | ⭐⭐⭐⭐ Maximum |

**Recommended workflow:**

1. Check the [vLLM Supported Models List](https://docs.vllm.ai/en/stable/models/supported_models.html) to see whether the model is already there.
2. If yes → copy the corresponding file from `vllm/model_executor/models/` to `vllm_kunlun/models/` and perform the replacements.
3. If no → read the [vLLM Adding a New Model documentation](https://docs.vllm.ai/en/stable/contributing/model/) to understand the principles first, then follow this manual.

---

### Step 1: Implement Model Files

Create a model file in the `vllm_kunlun/models/` directory, e.g., `my_new_model.py`.

#### 1.1 Key Replacement Comparison Table

| Component | vLLM Native Import | vLLM-Kunlun Replacement Import | Required? |
|------|-----------------|------------------------|---------|
| **Attention Layer** | `from vllm.attention import Attention` | `from vllm_kunlun.ops.attention.layer import Attention` | ✅ **Yes** |
| **SiluAndMul** | `from vllm.model_executor.layers.activation import SiluAndMul` | `from vllm_kunlun.ops.activation import SiluAndMul` | ✅ **Yes** |
| **GeluAndMul** | `...activation import GeluAndMul` | `from vllm_kunlun.ops.activation import GeluAndMul` | ⚠️ As needed |
| **QuickGELU** | `...activation import QuickGELU` | `from vllm_kunlun.ops.activation import QuickGELU` | ⚠️ As needed |
| **VocabParallelEmbedding** | `from vllm...vocab_parallel_embedding import VocabParallelEmbedding` | `from vllm_kunlun.ops.vocab_parallel_embedding import VocabParallelEmbedding` | ⚠️ Some models |
| **ParallelLMHead** | Same as above | `from vllm_kunlun.ops.vocab_parallel_embedding import ParallelLMHead` | ⚠️ Some models |
| **RoPE (special)** | `from vllm...rotary_embedding import get_rope` | `from vllm_kunlun.ops.rotary_embedding import Split_Norm_Rope` | ⚠️ MoE+QKNorm |
| **Linear / RMSNorm, etc.** | Use vLLM native directly | **No replacement needed** | — |

> 💡 **Core principle**: any component that invokes **device kernels** (Attention, Activation, Sampling) must be replaced with the Kunlun version; pure-PyTorch components (Linear, RMSNorm, RoPE, etc.) can use vLLM native directly.

#### 1.2 Standard Dense Decoder-Only Model Template

Refer to `qwen3.py` or `internlm2.py`:

```python
"""Inference-only MyNewModel compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import MyNewModelConfig  # HuggingFace config

# ==========================================
# ⭐ Key Replacement 1: Use Kunlun-customized Attention
# ==========================================
# Do not use: from vllm.attention import Attention
from vllm_kunlun.ops.attention.layer import Attention

# ==========================================
# ⭐ Key Replacement 2: Use Kunlun-customized Activation
# ==========================================
# Do not use: from vllm.model_executor.layers.activation import SiluAndMul
from vllm_kunlun.ops.activation import SiluAndMul

# Other layers can use vLLM native directly
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear, RowParallelLinear, MergedColumnParallelLinear
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.model_executor.models.interfaces import SupportsPP, SupportsLoRA
from vllm.model_executor.models.utils import (
    AutoWeightsLoader, PPMissingLayer, extract_layer_index,
    is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
    make_layers, maybe_prefix
)


# ============================
# 1. MLP Layer
# ============================
class MyNewModelMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, hidden_act,
                 quant_config=None, prefix=""):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size,
            bias=False, quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()  # ⭐ Use Kunlun version

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
```

#### 1.3 Key Implementation Requirements

- **All modules must accept a `prefix` parameter**, passed through `__init__()`.
- The **`@support_torch_compile` decorator** must be added to the main model class (e.g., `MyNewModel`).
- The **`load_weights()` method** must correctly handle weight name mapping (`stacked_params_mapping`).
- **Pipeline parallelism (PP)** requires helpers such as `PPMissingLayer` and `is_pp_missing_parameter`.
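The `stacked_params_mapping` requirement can be illustrated with a toy loading loop. Lists stand in for tensors so the sketch stays self-contained; the parameter names follow the vLLM convention but the data is invented:

```python
# (fused_param_name, checkpoint_name, shard_id): gate_proj and up_proj in the
# checkpoint both land in the single fused gate_up_proj parameter.
stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint = {
    "mlp.gate_proj.weight": [1, 2],
    "mlp.up_proj.weight": [3, 4],
    "mlp.down_proj.weight": [5, 6],
}

# Model parameters: gate_up_proj holds both shards back to back.
params = {
    "mlp.gate_up_proj.weight": [0, 0, 0, 0],
    "mlp.down_proj.weight": [0, 0],
}

for name, weight in checkpoint.items():
    for fused, ckpt, shard_id in stacked_params_mapping:
        if ckpt in name:
            target = name.replace(ckpt, fused)
            # Each shard occupies a contiguous slice of the fused parameter.
            width = len(weight)
            start = shard_id * width
            params[target][start:start + width] = weight
            break
    else:
        params[name] = weight  # plain (non-stacked) parameter

print(params["mlp.gate_up_proj.weight"])  # → [1, 2, 3, 4]
```

The real `load_weights()` does the same renaming but calls each parameter's `weight_loader` with the shard id instead of slicing by hand.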

---

### Step 2: Register the Model

Add the registration in `vllm_kunlun/models/__init__.py`:

```python
# vllm_kunlun/models/__init__.py

from vllm import ModelRegistry

def register_model():
    # ... existing model registrations ...

    # ⭐ Add your new model (lazy-loading string format)
    ModelRegistry.register_model(
        "MyNewModelForCausalLM",  # ← must match architectures in config.json
        "vllm_kunlun.models.my_new_model:MyNewModelForCausalLM"  # ← module path:class name
    )
```

**⚠️ Key Considerations:**

1. The **first parameter** of `register_model()` is the model's `architecture` identifier and **must exactly match the `"architectures"` field in the HuggingFace model's `config.json`**.

2. Use the **string format** for the module path (`"module:class"`) to get **lazy loading**, avoiding CUDA initialization conflicts (`RuntimeError: Cannot re-initialize CUDA in forked subprocess`).

3. If the model already exists in vLLM (e.g., `Qwen3ForCausalLM`), the Kunlun version **overwrites** the original vLLM version upon registration.

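Consideration 1 can be checked mechanically. A self-contained sketch (the config dict below is a made-up stand-in for a real `config.json`, and the registered-name set is hypothetical):

```python
import json

# Hypothetical config.json content for the model being adapted.
config_json = json.loads("""
{
  "architectures": ["MyNewModelForCausalLM"],
  "hidden_size": 4096
}
""")

# Names passed as the first argument to ModelRegistry.register_model().
registered_archs = {"Qwen3ForCausalLM", "MyNewModelForCausalLM"}

arch = config_json["architectures"][0]
assert arch in registered_archs, f"{arch} is not registered"
print(f"{arch} matches a registered architecture")
```

In a live environment, `ModelRegistry.get_supported_archs()` plays the role of `registered_archs`.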
---

### Step 3: Verify Registration

#### Case A: Overwriting an Existing vLLM Model Architecture

If your model architecture name (e.g., `"Qwen3ForCausalLM"`) already exists in vLLM, vLLM outputs the following log during registration:

```
WARNING [...] Model architecture Qwen3ForCausalLM is already registered,
and will be overwritten by the new model class
vllm_kunlun.models.qwen3:Qwen3ForCausalLM.
```

Seeing this log indicates a successful overwrite ✅.

#### Case B: Brand New Model Architecture

If you are registering an architecture that does not exist in vLLM, there is no log confirmation by default. It is recommended to verify manually during the debugging phase:

```python
from vllm import ModelRegistry

assert "MyNewModelForCausalLM" in ModelRegistry.get_supported_archs()
print("✅ Model registration successful!")
```

---

### Step 4: Testing

#### 4.1 Offline Inference Test

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/MyNewModel",
    trust_remote_code=True,
    dtype="float16",
    tensor_parallel_size=1,  # Verify with a single card first
)

outputs = llm.generate(
    ["Hello, please introduce yourself."],
    SamplingParams(temperature=0.7, max_tokens=256),
)
for output in outputs:
    print(output.outputs[0].text)
```

#### 4.2 Online Service Test

```bash
XPU_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 8888 \
    --model /path/to/MyNewModel \
    --trust-remote-code \
    --dtype float16 \
    --max-model-len 4096 \
    --block-size 64
```

#### 4.3 Accuracy Verification

It is recommended to compare results with HuggingFace Transformers CPU/GPU inference:

```python
# Transformers reference output
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("/path/to/MyNewModel", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("/path/to/MyNewModel")
inputs = tokenizer("Hello, please introduce yourself.", return_tensors="pt")
ref_ids = model.generate(**inputs, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(ref_ids[0], skip_special_tokens=True))
# Use greedy decoding on both sides when comparing token by token.
```

---

## III. Adaptation Guide for Special Model Types

### 3.1 MoE Models (e.g., Qwen3-MoE, DeepSeek-V3)

**Reference Files:**

- `vllm_kunlun/models/qwen3_moe.py`
- `vllm_kunlun/models/deepseek_v2.py`

**Additional Points:**

- Use `vllm.model_executor.layers.fused_moe.layer.FusedMoE`; Kunlun has already replaced the underlying kernel via the import hook.
- MoE's `load_weights()` is more complex and requires an expert parameter mapping:

```python
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=config.n_routed_experts,
)
```

- Recommended environment variables:

```bash
export KUNLUN_USE_MOE_FFN_BLOCK=True
export XPU_USE_MOE_SORTED_THRES=120
```

### 3.2 MLA (Multi-Latent Attention) Models (e.g., DeepSeek-V3)

**Reference File:** `vllm_kunlun/models/deepseek_v2.py`

**MLA Special Handling:**

- KV compression dimensions: `kv_lora_rank`, `qk_nope_head_dim`, `qk_rope_head_dim`.
- The platform layer automatically selects `FlashMLABackend`:

```python
# vllm_kunlun/platforms/kunlun.py
if use_mla:
    if use_sparse:
        return "vllm_kunlun.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
    return "vllm_kunlun.v1.attention.backends.mla.flashmla.FlashMLABackend"
```

- `block_size` usually needs to be set to **64**.
- Recommended setting: `export USE_ORI_ROPE=1`.

### 3.3 Multi-modal Models (e.g., Qwen2-VL, InternVL)

**Reference Files:**

- `vllm_kunlun/models/qwen3_vl.py`
- `vllm_kunlun/models/internvl.py`
- `vllm_kunlun/models/interns1.py`

**Additional Components to Implement:**

| Component | Description |
|------|------|
| `SupportsMultiModal` interface | Declares that the model supports multi-modal input |
| Vision Encoder | Usually `InternVisionModel` or a custom ViT |
| Projector | Vision → language mapping (e.g., an MLP) |
| `@MULTIMODAL_REGISTRY.register_processor(...)` | Registers the multi-modal processor |
| `BaseMultiModalProcessor` | Handles multi-modal input |
| `BaseProcessingInfo` | Provides processing info |
| `BaseDummyInputsBuilder` | Dummy inputs for the profiling phase |

### 3.4 Hybrid Attention Models (e.g., Qwen3-Next)

**Reference File:** `vllm_kunlun/models/qwen3_next.py`

This model contains both **linear attention** and **full attention** layer types:

```python
# Select the attention computation based on layer_type
if self.layer_type == "linear_attention":
    self.linear_attn(hidden_states=hidden_states, output=self_attention_output)
elif self.layer_type == "full_attention":
    self.self_attn(hidden_states=hidden_states, output=self_attention_output, positions=positions)
```

Note:

- Linear attention uses `GatedDeltaNet` or a similar implementation.
- A custom `custom_op` (e.g., `vllm.gdn_attention`) must be registered for `splitting_ops` in `torch.compile`.

---

## IV. Quantized Model Adaptation

### 4.1 Supported Quantization Methods

| Quantization Method | Adaptation File | Status |
|---------|---------|------|
| **INT8 Dynamic (W8A8)** | `ops/quantization/kernels/kunlun_scale_mm.py` | ✅ Recommended |
| **AWQ (INT4)** | `ops/quantization/awq.py` | ✅ Supported |
| **GPTQ (INT4)** | `ops/quantization/gptq.py` | ✅ Supported |
| **CompressedTensors (INT8 MoE)** | `ops/quantization/compressed_tensors/` | ✅ Supported |
| **FP8** | — | ⚠️ Partial support |
| **bfloat16** | — | ⚠️ Double-VRAM bug |

### 4.2 Special Handling for Quantization

Kunlun chips use the **max value** for scale calculation instead of vLLM's default absmax format:

```python
# ops/quantization/kernels/kunlun_scale_mm.py
class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    def process_weights_after_loading(self, layer):
        super().process_weights_after_loading(layer)
        # ⭐ Key: multiply the scale by 127.0 to convert to max format
        with torch.no_grad():
            getattr(layer, self.w_s_name).mul_(127.0)
```

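The effect of the `mul_(127.0)` can be seen with plain numbers: if the stored scale is `absmax / 127` (a common int8 convention), multiplying by 127 recovers the raw per-channel max. This is a toy calculation with invented values, not the real kernel path:

```python
# Toy per-channel absmax values (chosen so the arithmetic is float-exact).
channel_absmax = [127.0, 254.0, 63.5]

# Scale stored in absmax convention: scale = absmax / 127.
stored_scales = [m / 127.0 for m in channel_absmax]   # [1.0, 2.0, 0.5]

# One-time conversion to "max format" as done after weight loading.
kunlun_scales = [s * 127.0 for s in stored_scales]

print(kunlun_scales)  # → [127.0, 254.0, 63.5] — the raw channel max again
```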
INT4 weights need to be **repacked** into the Kunlun layout order:

```python
# AWQ repack example
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL]
```

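The reorder can be exercised on plain integers. The sketch below unpacks one 32-bit word into eight 4-bit values, permutes them with `AWQ_TO_KUNLUN_ORDER_NORMAL`, and repacks (a pure-Python stand-in for the tensor version; helper names are invented):

```python
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]

def unpack_int4(word):
    """Split a 32-bit word into eight 4-bit values, lowest nibble first."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]

def pack_int4(nibbles):
    """Inverse of unpack_int4."""
    word = 0
    for i, n in enumerate(nibbles):
        word |= (n & 0xF) << (4 * i)
    return word

# A word whose nibbles are simply 0..7, so the permutation is easy to see.
awq_word = pack_int4([0, 1, 2, 3, 4, 5, 6, 7])
nibbles = unpack_int4(awq_word)
kunlun_nibbles = [nibbles[i] for i in AWQ_TO_KUNLUN_ORDER_NORMAL]
print(kunlun_nibbles)  # → [4, 0, 5, 1, 6, 2, 7, 3]
kunlun_word = pack_int4(kunlun_nibbles)
```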
---

## V. Custom Operators (if new low-level Ops are needed)

If your model requires new low-level operators:

### 5.1 Wrap kunlun_ops calls in `_kunlun_ops.py`

```python
# vllm_kunlun/ops/_kunlun_ops.py
class KunlunOps:
    @staticmethod
    def my_new_op(input, weight, out):
        """Call the underlying kunlun_ops implementation"""
        kunlun_ops.my_new_op(input, weight, out=out)
```

### 5.2 Register to vLLM in `_custom_ops.py`

Follow the **three-piece pattern**:

```python
# vllm_kunlun/ops/_custom_ops.py

# 1. Define the actual implementation of the op
def my_new_op_impl(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(input)
    KunlunOps.my_new_op(input, weight, output)
    return output

# 2. Define the fake-tensor implementation (for torch.compile)
def my_new_op_fake(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input)

# 3. Register
direct_register_custom_op(
    op_name="my_new_op",
    op_func=my_new_op_impl,
    mutates_args=[],
    fake_impl=my_new_op_fake,
)
```

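The essence of the real/fake split can be mimicked without torch: a registry that holds both implementations and lets a tracing mode pick the shape-only one. This is a toy analogue of what the registration does for `torch.compile`, not the vLLM API:

```python
# Toy analogue of the real/fake pattern: the fake impl must only produce
# correctly-shaped output and never touch the actual data.
_OPS = {}

def register_op(name, impl, fake_impl):
    _OPS[name] = (impl, fake_impl)

def call_op(name, *args, tracing=False):
    impl, fake_impl = _OPS[name]
    return (fake_impl if tracing else impl)(*args)

def my_new_op_impl(xs, w):
    return [x * w for x in xs]      # the real computation

def my_new_op_fake(xs, w):
    return [0.0] * len(xs)          # shape-only placeholder

register_op("my_new_op", my_new_op_impl, my_new_op_fake)

print(call_op("my_new_op", [1.0, 2.0], 3.0))                # → [3.0, 6.0]
print(call_op("my_new_op", [1.0, 2.0], 3.0, tracing=True))  # → [0.0, 0.0]
```

During graph capture, `torch.compile` only needs output shapes and dtypes, which is why the fake implementation returns an empty tensor of the right size.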
---
|
||||
|
||||
## VI. Common Pitfalls Checklist
|
||||
|
||||
Before submitting a PR, please check each item:
|
||||
|
||||
- [ ] **Attention** uses `vllm_kunlun.ops.attention.layer.Attention`?
|
||||
- [ ] **Activation functions** use `vllm_kunlun.ops.activation.SiluAndMul`, etc.?
|
||||
- [ ] All submodules in `__init__()` have the `prefix` parameter passed?
|
||||
- [ ] `load_weights()` correctly handles weight name mapping (`stacked_params_mapping`)?
|
||||
- [ ] `@support_torch_compile` decorator is added to the main model class?
|
||||
- [ ] The first parameter of `ModelRegistry.register_model()` exactly matches `architectures` in `config.json`?
|
||||
- [ ] No use of `VLLM_USE_V1` environment variable for logic (deprecated, v0.15.1 is V1-only)?
|
||||
- [ ] Type annotations use `Optional[T]` instead of `T | None` (to avoid `infer_schema` failure)?
|
||||
- [ ] Quantized model scales are correctly multiplied by `127.0`?
|
||||
- [ ] Supports Pipeline Parallelism (using `PPMissingLayer`, `is_pp_missing_parameter`)?
|
||||
- [ ] Ran `pre-commit` format checks?
|
||||
- [ ] Commits use `-s` signature (DCO compliance)?
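
The `127.0` item above refers to the symmetric int8 convention some Kunlun kernels use, where the checkpoint scale must be pre-multiplied before being handed to the op. A minimal sketch of that convention (the exact kernel expectation is an assumption here — verify against your op's documentation):

```python
def to_kernel_scale(checkpoint_scale: float) -> float:
    # Some Kunlun int8 kernels expect the scale pre-multiplied by 127.0.
    return checkpoint_scale * 127.0

def quantize(x: float, scale: float) -> int:
    # Symmetric int8 quantization, clamped to [-127, 127].
    return max(-127, min(127, round(x / scale)))

print(quantize(1.0, 0.5))    # 2
print(quantize(100.0, 0.5))  # 127 (clamped)
print(to_kernel_scale(0.5))  # 63.5
```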

---

## VII. Reference Template Quick Look-up

| Model Type | Best Reference File | Features |
|---------|------------|------|
| Standard Dense LLM | `qwen3.py` | Simplest, recommended for beginners |
| Dense LLM (Custom Embedding) | `seed_oss.py`, `internlm2.py` | Custom VocabParallelEmbedding |
| MoE LLM | `qwen3_moe.py` | FusedMoE + EP + SharedExpert |
| MLA + MoE (DeepSeek) | `deepseek_v2.py` | MLA attention + MoE + Indexer |
| Hybrid Attention | `qwen3_next.py` | Linear + Full attention |
| Multi-modal (VL) | `qwen3_vl.py`, `internvl.py` | ViT + Projector + LLM |
| Speculative Decoding (MTP) | `deepseek_mtp.py` | Multi-Token Prediction |

---

## VIII. Debugging Tips

### 8.1 Startup Failure

- **`ModuleNotFoundError`**: Check whether the import hook mapping table in `__init__.py` covers the corresponding module.
- **`circular import`**: Check whether your new code introduces heavy dependencies during the `register()` phase.
- **`Model architecture XXX is not supported`**: Check whether the first parameter of `register_model()` matches `config.json`.

### 8.2 Abnormal Output

- **Garbage output**: Compare with HF transformers output on CPU; likely an operator precision issue or a weight-loading mapping error.
- **Repeated tokens**: Check whether `rotary_embedding` is applied correctly and the `is_neox_style` parameter is correct.
- **Truncated output**: Check the `max_model_len` setting and whether the KV cache is sufficient.

### 8.3 VRAM Issues

- Use `--dtype float16` (avoid bfloat16 due to a double-VRAM bug).
- Set `VLLM_KUNLUN_ENABLE_INT8_BMM=1` (saves ~0.1GB).
- Lower `--gpu-memory-utilization` (default is 0.9).
- Use INT8 quantized models.

### 8.4 Weight Loading Failure

```python
# Debugging method: print parameter names for comparison
params_dict = dict(self.named_parameters())
print("=== Model params ===")
for k in sorted(params_dict.keys()):
    print(f"  {k}: {params_dict[k].shape}")

# Print inside load_weights
for name, loaded_weight in weights:
    if name not in params_dict:
        print(f"  ⚠️ Skipped: {name}")
```
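
Most "Skipped" warnings come from the `stacked_params_mapping` rename step, where checkpoint names like `q_proj` must be rewritten to the fused parameter name. A minimal, dependency-free sketch of that step (the mapping entries below follow the common vLLM `(param_name, shard_name, shard_id)` pattern; the specifics vary per model):

```python
# Illustrative mapping: (fused param name, checkpoint shard name, shard id)
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name):
    """Return (fused_param_name, shard_id) for a checkpoint weight name."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None  # not a stacked shard; load as-is

print(remap("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
```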

### 8.5 Kunlun Graph Failure

Confirm that `splitting_ops` in `compilation-config` includes your attention op name:

```json
{
  "splitting_ops": [
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
    "vllm.unified_attention_with_output_kunlun",
    "vllm.sparse_attn_indexer_vllm_kunlun"
  ],
  "cudagraph_mode": "PIECEWISE"
}
```
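
A malformed `--compilation-config` string fails late and cryptically, so it can help to sanity-check the JSON before launching (a small ad-hoc check, not part of vLLM-Kunlun):

```python
import json

# Validate a --compilation-config string: parseable JSON, and the
# Kunlun attention op is present in splitting_ops.
cfg_str = '''{
  "splitting_ops": [
    "vllm.unified_attention",
    "vllm.unified_attention_with_output",
    "vllm.unified_attention_with_output_kunlun",
    "vllm.sparse_attn_indexer_vllm_kunlun"
  ],
  "cudagraph_mode": "PIECEWISE"
}'''

cfg = json.loads(cfg_str)  # raises ValueError on malformed JSON
assert "vllm.unified_attention_with_output_kunlun" in cfg["splitting_ops"]
print("ok:", cfg["cudagraph_mode"])  # ok: PIECEWISE
```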

---

## IX. Environment Variables Cheat Sheet

```bash
# === Required ===
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7    # Specify Kunlun cards to use
export VLLM_HOST_IP=$(hostname -i)            # IP for distributed communication

# === Recommended ===
export XMLIR_FORCE_USE_XPU_GRAPH=1            # Enable XPU Graph acceleration
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false  # Disable mock compile
export XMLIR_CUDNN_ENABLED=1                  # Enable cuDNN-equivalent acceleration
export XPU_USE_DEFAULT_CTX=1                  # Default context
export BKCL_FORCE_SYNC=1                      # BKCL forced sync (multi-card stability)

# === Model Specific ===
export USE_ORI_ROPE=1                         # DeepSeek series uses original RoPE
export XFT_USE_FAST_SWIGLU=1                  # Fast SwiGLU activation
export XPU_USE_FAST_SWIGLU=1                  # Same as above (some versions)
export XPU_USE_MOE_SORTED_THRES=120           # MoE sorting threshold
export KUNLUN_USE_MOE_FFN_BLOCK=True          # MoE FFN block optimization

# === Optional Tuning ===
export VLLM_KUNLUN_ENABLE_INT8_BMM=1          # Enable INT8 BMM (saves ~0.1GB)
```
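
Launcher code typically reads these flags with a default when unset. An illustrative reader (the truthiness rules below are an assumption, not vLLM-Kunlun's exact parsing):

```python
import os

def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean-ish environment variable with a default."""
    val = os.environ.get(name)
    if val is None:
        return default
    return val.strip().lower() not in ("0", "false", "")

os.environ["XMLIR_FORCE_USE_XPU_GRAPH"] = "1"
os.environ["XMLIR_ENABLE_MOCK_TORCH_COMPILE"] = "false"
print(env_flag("XMLIR_FORCE_USE_XPU_GRAPH"))        # True
print(env_flag("XMLIR_ENABLE_MOCK_TORCH_COMPILE"))  # False
```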

---

## X. PR Submission Standards

### 10.1 Branch Naming

```
feature/add-my-new-model
bugfix/fix-attention-output
```

### 10.2 Commit Message Prefix

| Prefix | Description |
|------|------|
| `[Feature]` | New functionality / New model |
| `[Bugfix]` | Bug fix |
| `[CI/Build]` | CI / Build related |
| `[Doc]` | Documentation update |
| `[Misc]` | Others |

### 10.3 Before Submission

```bash
# 1. Install pre-commit
pre-commit install

# 2. Run checks
pre-commit run --all-files

# 3. Signed commit (DCO compliance)
git commit -s -m "[Feature] Add MyNewModel support for Kunlun"
```

### 10.4 PR Checklist

- [ ] Code passes `pre-commit` checks.
- [ ] Single-card offline inference test passed.
- [ ] Multi-card TP test passed (if applicable).
- [ ] Quantized model test passed (if applicable).
- [ ] Updated `vllm_kunlun/models/__init__.py` registration.
- [ ] Updated supported models list in README (if applicable).

---

## Appendix: Standard Startup Command Templates

### A. Standard Dense Model (Single Card)

```bash
XPU_VISIBLE_DEVICES=0 \
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 8888 \
    --model /path/to/model \
    --trust-remote-code \
    --dtype float16 \
    --max-model-len 8192 \
    --block-size 64
```

### B. MoE Model (8-card TP)

```bash
XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
XMLIR_FORCE_USE_XPU_GRAPH=1 \
KUNLUN_USE_MOE_FFN_BLOCK=True \
XPU_USE_MOE_SORTED_THRES=120 \
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 8888 \
    --model /path/to/moe-model-int8 \
    --trust-remote-code \
    --dtype float16 \
    --max-model-len 32768 \
    --tensor-parallel-size 8 \
    --max_num_seqs 4 \
    --block-size 64 \
    --no-enable-chunked-prefill \
    --distributed-executor-backend mp \
    --no-enable-prefix-caching
```

### C. DeepSeek-V3 (MLA + MoE, W8A8)

```bash
XMLIR_ENABLE_MOCK_TORCH_COMPILE=false \
USE_ORI_ROPE=1 \
XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 --port 8806 \
    --model /path/to/DeepSeek-V3-w8a8 \
    --gpu-memory-utilization 0.98 \
    --trust-remote-code \
    --max-model-len 32768 \
    --tensor-parallel-size 8 \
    --dtype float16 \
    --max_num_seqs 4 \
    --block-size 64 \
    --no-enable-chunked-prefill \
    --distributed-executor-backend mp \
    --no-enable-prefix-caching
```

---

> 📝 **Document Maintenance**: If you have questions or suggestions, please provide feedback in [GitHub Issues](https://github.com/baidu/vLLM-Kunlun/issues).
@@ -6,17 +6,23 @@ torch_xray is an operator precision analysis tool that can dump module-level inp

### 1. Download and install

**python3.12:**

```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp312-cp312-linux_x86_64.whl"
```

**python3.10:**

```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl"
```

**python3.8:**

```
pip install "https://klx-sdk-release-public.su.bcebos.com/torch_xray/release/2.0.3.0/torch_xray-2.0.3-cp38-cp38-linux_x86_64.whl"
```

Note that the same installation package must be used across different environments.
@@ -52,6 +52,7 @@ user_guide/release_notes
:::{toctree}
:caption: Developer Guide
:maxdepth: 1
developer_guide/developer_guide
developer_guide/contribution/index
developer_guide/feature_guide/index
developer_guide/evaluation/index
@@ -75,55 +75,34 @@ cp vllm_kunlun/patches/eval_frame.py /root/miniconda/envs/vllm_kunlun_0.10.1.1/l
## Choose to download customized xpytorch

### Install the KL3-customized build of PyTorch

```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://baidu-kunlun-public.su.bcebos.com/baidu-kunlun-share/20260206/xpytorch-cp310-torch251-ubuntu2004-x64.run

# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run

# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
    sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
```

### Install the KL3-customized build of PyTorch (Only MIMO V2)

```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1231/xpytorch-cp310-torch251-ubuntu2004-x64.run

# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run

# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
    sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g' setup.sh && bash setup.sh
```

### Install the KL3-customized build of PyTorch (Only DeepSeek-V3.2-Exp-w8a8)

```
wget -O xpytorch-cp310-torch251-ubuntu2004-x64.run https://aihc-private-hcd.bj.bcebos.com/v1/vllm-kunlun-ds/xpytorch-cp310-torch251-ubuntu2004-x64.run?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2026-02-03T01%3A59%3A40Z%2F-1%2Fhost%2Ffc4b6f5b83c2fde70d48fdfc23c40c396efc9cb3c36d6f811fdca5f109073321

# for conda
bash xpytorch-cp310-torch251-ubuntu2004-x64.run

# for uv
bash xpytorch-cp310-torch251-ubuntu2004-x64.run --noexec --target xpytorch_unpack && cd xpytorch_unpack/ && \
    mv torch_xray-999.9.9-cp310-cp310-linux_x86_64.whl torch_xray-2.0.3-cp310-cp310-linux_x86_64.whl && \
    sed -i 's/pip/uv pip/g; s/CONDA_PREFIX/VIRTUAL_ENV/g; s/torch_xray-999.9.9/torch_xray-2.0.3/' setup.sh && bash setup.sh
```

## Choose to download customized ops

### Install custom ops

```
uv pip install "https://baidu-kunlun-public.su.bcebos.com/v1/baidu-kunlun-share/1130/xtorch_ops-0.1.2209%2B6752ad20-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKypXxBzU7gg4Mk4K4c6OYR%2F2025-12-05T06%3A18%3A00Z%2F-1%2Fhost%2F14936c2b7e7c557c1400e4c467c79f7a9217374a7aa4a046711ac4d948f460cd"
```

### Install custom ops (Only MIMO V2)

```
uv pip install "https://vllm-ai-models.bj.bcebos.com/v1/vLLM-Kunlun/ops/swa/xtorch_ops-0.1.2109%252B523cb26d-cp310-cp310-linux_x86_64.whl"
```

### Install custom ops (Only DeepSeek-V3.2-Exp-w8a8)

```
uv pip install "https://klx-sdk-release-public.su.bcebos.com/kunlun2aiak_output/1215/xtorch_ops-0.1.2263%2Bc030eebd-cp310-cp310-linux_x86_64.whl"
uv pip install "https://baidu-kunlun-public.su.bcebos.com/baidu-kunlun-share/20260206/kunlun_ops-0.1.45%2Bbac5499e-cp310-cp310-linux_x86_64.whl"
```

## Install the KLX3 custom Triton build

```
uv pip install "https://cce-ai-models.bj.bcebos.com/v1/vllm-kunlun-0.11.0/triton-3.0.0%2Bb2cde523-cp310-cp310-linux_x86_64.whl"
```

## Install the AIAK custom ops library

```
uv pip install "https://vllm-ai-models.bj.bcebos.com/XSpeedGate-whl/release_merge/20260130_152557/xspeedgate_ops-0.0.0%2Be5cdcbe-cp310-cp310-linux_x86_64.whl?authorization=bce-auth-v1%2FALTAKhvtgrTA8US5LIc8Vbl0mP%2F2026-01-30T10%3A33%3A32Z%2F2592000%2Fhost%2F3c13d67cc61d0df7538c198f5c32422f3b034068a40eef43cb51b079cc6f0555" --force-reinstall
```
@@ -8,5 +8,7 @@ single_xpu_Qwen3-VL-32B
single_xpu_InternVL2_5-26B
multi_xpu_Qwen2.5-VL-32B
multi_xpu_GLM-4.5
multi_xpu_GLM-5-W8A8-INT8
multi_xpu_DeepSeek-V3.2-Exp-w8a8
multi_xpu_Qwen3-Coder-480B-A35B(W8A8)
:::
@@ -7,6 +7,7 @@ Setup environment using container:
Please follow the [installation.md](../installation.md) document to set up the environment first.

Create a container

```bash
#!/bin/bash
# rundocker.sh
@@ -36,13 +37,16 @@ docker run -itd ${DOCKER_DEVICE_CONFIG} \
### Preparation Weight

- Pull DeepSeek-V3.2-Exp-w8a8-int8 weights

```
wget -O DeepSeek-V3.2-Exp-w8a8-int8.tar.gz https://aihc-private-hcd.bj.bcebos.com/v1/LLM/DeepSeek/DeepSeek-V3.2-Exp-w8a8-int8.tar.gz?authorization=bce-auth-v1%2FALTAKvz6x4eqcmSsKjQxq3vZdB%2F2025-12-24T06%3A07%3A10Z%2F-1%2Fhost%2Fa324bf469176934a05f75d3acabc3c1fb891be150f43fb1976e65b7ec68733db
```

- Ensure that the field "quantization_config" is included. If not, deployment will fail with an OOM (Out of Memory) error.

vim model/DeepSeek-V3.2-Exp-w8a8-int8/config.json

```json
"quantization_config": {
    "config_groups": {
        "group_0": {
@@ -108,7 +112,7 @@ export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
export XPU_USE_MOE_SORTED_THRES=1 && \
export USE_ORI_ROPE=1 && \
export VLLM_USE_V1=1

python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
@@ -129,9 +133,9 @@ python -m vllm.entrypoints.openai.api_server \
    --compilation-config '{"splitting_ops":["vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.unified_attention_with_output_kunlun",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
        "vllm.short_conv",
        "vllm.linear_attention",
        "vllm.plamo2_mamba_mixer",
        "vllm.gdn_attention",
docs/source/tutorials/multi_xpu_GLM-5-W8A8-INT8.md (92 lines, new file)
@@ -0,0 +1,92 @@
# Multi XPU (GLM-5-W8A8-INT8)

## Run vllm-kunlun on Multi XPU

Setup environment using container:

Please follow the [installation.md](../installation.md) document to set up the environment first.

Create a container

```bash
#!/bin/bash
# rundocker.sh
XPU_NUM=8
DOCKER_DEVICE_CONFIG=""
if [ $XPU_NUM -gt 0 ]; then
    for idx in $(seq 0 $((XPU_NUM-1))); do
        DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpu${idx}:/dev/xpu${idx}"
    done
    DOCKER_DEVICE_CONFIG="${DOCKER_DEVICE_CONFIG} --device=/dev/xpuctrl:/dev/xpuctrl"
fi

export build_image="xxx"

docker run -itd ${DOCKER_DEVICE_CONFIG} \
    --net=host \
    --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
    --tmpfs /dev/shm:rw,nosuid,nodev,exec,size=32g \
    -v /home/users/vllm-kunlun:/home/vllm-kunlun \
    -v /usr/local/bin/xpu-smi:/usr/local/bin/xpu-smi \
    --name "$1" \
    -w /workspace \
    "$build_image" /bin/bash
```

### Preparation Weight

- Pull GLM-5-W8A8-INT8 weights

```
wget -O GLM-5-W8A8-INT8-Dynamic.tar.gz https://aihc-private-hcd.bj.bcebos.com/LLM/AICapX-Quant-Models/GLM-5-W8A8-INT8-Dynamic.tar.gz
```

### Online Serving on Multi XPU

Start the vLLM server on multi XPU:

```bash
unset XPU_DUMMY_EVENT && \
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 && \
export XMLIR_CUDNN_ENABLED=1 && \
export XPU_USE_DEFAULT_CTX=1 && \
export XMLIR_FORCE_USE_XPU_GRAPH=1 && \
export XMLIR_ENABLE_FAST_FC=1 && \
export XPU_USE_FAST_SWIGLU=1 && \
export CUDA_GRAPH_OPTIMIZE_STREAM=1 && \
export XMLIR_ENABLE_MOCK_TORCH_COMPILE=false && \
export XPU_USE_MOE_SORTED_THRES=1 && \
export USE_ORI_ROPE=1 && \
export VLLM_USE_V1=1

python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 8806 \
    --model GLM-5-W8A8-INT8-Dynamic \
    --gpu-memory-utilization 0.97 \
    --trust-remote-code \
    --max-model-len 32768 \
    --tensor-parallel-size 8 \
    --dtype bfloat16 \
    --max_num_seqs 8 \
    --max_num_batched_tokens 8192 \
    --block-size 64 \
    --no-enable-chunked-prefill \
    --distributed-executor-backend mp \
    --disable-log-requests \
    --no-enable-prefix-caching \
    --kv-cache-dtype bfloat16 \
    --compilation-config '{
        "splitting_ops":[
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
            "vllm.unified_attention_with_output_kunlun",
            "vllm.mamba_mixer2",
            "vllm.mamba_mixer",
            "vllm.short_conv",
            "vllm.linear_attention",
            "vllm.plamo2_mamba_mixer",
            "vllm.gdn_attention",
            "vllm.sparse_attn_indexer",
            "vllm.sparse_attn_indexer_vllm_kunlun"
        ]}'
```
@@ -86,8 +86,10 @@ if __name__ == "__main__":
main()

```

:::::
If you run this script successfully, you can see the info shown below:

```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': '你好!你是谁?'}]}]
@@ -95,9 +97,11 @@ Model response:
你好!我是一个由人工智能驱动的助手,旨在帮助回答问题、提供信息和解决日常问题。请问有什么我可以帮助你的?
==================================================
```

### Online Serving on Single XPU

Start the vLLM server on a single XPU:

```bash
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 9988 \
@@ -114,25 +118,29 @@ python -m vllm.entrypoints.openai.api_server \
    --no-enable-chunked-prefill \
    --distributed-executor-backend mp \
    --served-model-name InternVL2_5-26B \
    --compilation-config '{"splitting_ops": ["vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.unified_attention_with_output_kunlun",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
        "vllm.short_conv",
        "vllm.linear_attention",
        "vllm.plamo2_mamba_mixer",
        "vllm.gdn_attention",
        "vllm.sparse_attn_indexer"]}'
# Version 0.11.0
```

If your service starts successfully, you can see the info shown below:

```bash
(APIServer pid=157777) INFO: Started server process [157777]
(APIServer pid=157777) INFO: Waiting for application startup.
(APIServer pid=157777) INFO: Application startup complete.
```

Once your server is started, you can query the model with input prompts:

```bash
curl http://localhost:9988/v1/completions \
    -H "Content-Type: application/json" \
@@ -145,17 +153,23 @@ curl http://localhost:9988/v1/completions \
    "top_k": 50
    }'
```

If you query the server successfully, you can see the info shown below (client):

```bash
{"id":"cmpl-23a24afd616d4a47910aeeccb20921ed","object":"text_completion","created":1768891222,"model":"InternVL2_5-26B","choices":[{"index":0,"text":" 你有什么问题吗?\n\n你好!我是书生·AI,很高兴能与你交流。请问有什么我可以帮助你的吗?无论是解答问题、提供信息还是其他方面的帮助,我都会尽力而为。请告诉我你的需求。","logprobs":null,"finish_reason":"stop","stop_reason":92542,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":6,"total_tokens":53,"completion_tokens":47,"prompt_tokens_details":null},"kv_transfer_params":null}
```

Logs of the vllm server:

```bash
(APIServer pid=161632) INFO: 127.0.0.1:56708 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=161632) INFO 01-20 14:40:25 [loggers.py:127] Engine 000: Avg prompt throughput: 0.6 tokens/s, Avg generation throughput: 4.6 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=161632) INFO 01-20 14:40:35 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```

Input an image for testing. Here, a Python script is used:

```python
import requests
import base64
@@ -193,13 +207,17 @@ payload = {
response = requests.post(API_URL, json=payload)
print(response.json())
```

If you query the server successfully, you can see the info shown below (client):

```bash
{'id': 'chatcmpl-9aeab6044795458da04f2fdcf1d0445d', 'object': 'chat.completion', 'created': 1768891349, 'model': 'InternVL2_5-26B', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': '你好!这张图片上有一个黄色的笑脸表情符号,双手合十,旁边写着"Hugging Face"。这个表情符号看起来很开心,似乎在表示拥抱或欢迎。', 'refusal': None, 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [], 'reasoning_content': None}, 'logprobs': None, 'finish_reason': 'stop', 'stop_reason': 92542, 'token_ids': None}], 'service_tier': None, 'system_fingerprint': None, 'usage': {'prompt_tokens': 790, 'total_tokens': 827, 'completion_tokens': 37, 'prompt_tokens_details': None}, 'prompt_logprobs': None, 'prompt_token_ids': None, 'kv_transfer_params': None}
```

Logs of the vllm server:

```bash
(APIServer pid=161632) INFO: 127.0.0.1:58686 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=161632) INFO 01-20 14:42:35 [loggers.py:127] Engine 000: Avg prompt throughput: 79.0 tokens/s, Avg generation throughput: 3.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=161632) INFO 01-20 14:42:45 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
```
@@ -85,19 +85,23 @@ if __name__ == "__main__":
main()

```

:::::
If you run this script successfully, you can see the info shown below:

```bash
==================================================
Input content: [{'role': 'user', 'content': [{'type': 'text', 'text': 'tell a joke'}]}]
Model response:
Why don’t skeletons fight each other?
Because they don’t have the guts! 🦴😄
==================================================
```

### Online Serving on Single XPU

Start the vLLM server on a single XPU:

```bash
python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 9988 \
@@ -114,25 +118,29 @@ python -m vllm.entrypoints.openai.api_server \
    --no-enable-chunked-prefill \
    --distributed-executor-backend mp \
    --served-model-name Qwen3-VL-32B \
    --compilation-config '{"splitting_ops": ["vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.unified_attention_with_output_kunlun",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
        "vllm.short_conv",
        "vllm.linear_attention",
        "vllm.plamo2_mamba_mixer",
        "vllm.gdn_attention",
        "vllm.sparse_attn_indexer"]}'
# Version 0.11.0
```

If your service starts successfully, you can see the info shown below:

```bash
(APIServer pid=109442) INFO: Started server process [109442]
(APIServer pid=109442) INFO: Waiting for application startup.
(APIServer pid=109442) INFO: Application startup complete.
```

Once your server is started, you can query the model with input prompts:

```bash
curl http://localhost:9988/v1/completions \
    -H "Content-Type: application/json" \
@@ -143,11 +151,15 @@ curl http://localhost:9988/v1/completions \
    "temperature": 0
    }'
```

If you query the server successfully, you can see the info shown below (client):

```bash
{"id":"cmpl-4f61fe821ff34f23a91baade5de5103e","object":"text_completion","created":1768876583,"model":"Qwen3-VL-32B","choices":[{"index":0,"text":" 你好!我是通义千问,是阿里云研发的超大规模语言模型。我能够回答问题、创作文字、编程等,还能根据你的需求进行多轮对话。有什么我可以帮你的吗?😊\n\n(温馨提示:我是一个AI助手,虽然我尽力提供准确和有用的信息,但请记得在做重要决策时,最好结合专业意见或进一步核实信息哦!)","logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":5,"total_tokens":90,"completion_tokens":85,"prompt_tokens_details":null},"kv_transfer_params":null}
```

Logs of the vllm server:

```bash
(APIServer pid=109442) INFO: 127.0.0.1:19962 - "POST /v1/completions HTTP/1.1" 200 OK
(APIServer pid=109442) INFO 01-20 10:36:28 [loggers.py:127] Engine 000: Avg prompt throughput: 0.5 tokens/s, Avg generation throughput: 8.5 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
@@ -155,7 +167,9 @@ Logs of the vllm server:
(APIServer pid=109442) INFO 01-20 10:43:23 [chat_utils.py:560] Detected the chat template content format to be 'openai'. You can set `--chat-template-content-format` to override this.
(APIServer pid=109442) INFO 01-20 10:43:28 [loggers.py:127] Engine 000: Avg prompt throughput: 9.0 tokens/s, Avg generation throughput: 6.9 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.5%, Prefix cache hit rate: 0.0%
```

Input an image for testing. Here, a Python script is used:

```python
import requests
import base64
@@ -191,11 +205,15 @@ payload = {
response = requests.post(API_URL, json=payload)
print(response.json())
```

If you query the server successfully, you can see the info shown below (client):

```bash
{'id': 'chatcmpl-4b42fe46f2c84991b0af5d5e1ffad9ba', 'object': 'chat.completion', 'created': 1768877003, 'model': 'Qwen3-VL-32B', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': '你好!这张图片展示的是"Hugging Face"的标志。\n\n图片左侧是一个黄色的圆形表情符号(emoji),它有着圆圆的眼睛、张开的嘴巴露出微笑,双手合拢在脸颊两侧,做出一个拥抱或欢迎的姿态,整体传达出友好、温暖和亲切的感觉。\n\n图片右侧是黑色的英文文字"Hugging Face",字体简洁现代,与左侧的表情符号相呼应。\n\n整个标志设计简洁明了,背景为纯白色,突出了标志本身。这个标志属于Hugging Face公司,它是一家知名的开源人工智能公司,尤其在自然语言处理(NLP)领域以提供预训练模型(如Transformers库)和模型托管平台而闻名。\n\n整体来看,这个标志通过可爱的表情符号和直白的文字,成功传达了公司"拥抱"技术、开放共享、友好的品牌理念。', 'refusal': None, 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [], 'reasoning_content': None}, 'logprobs': None, 'finish_reason': 'stop', 'stop_reason': None, 'token_ids': None}], 'service_tier': None, 'system_fingerprint': None, 'usage': {'prompt_tokens': 90, 'total_tokens': 266, 'completion_tokens': 176, 'prompt_tokens_details': None}, 'prompt_logprobs': None, 'prompt_token_ids': None, 'kv_transfer_params': None}
```

Logs of the vllm server:

```bash
(APIServer pid=109442) INFO: 127.0.0.1:26854 - "POST /v1/chat/completions HTTP/1.1" 200 OK
(APIServer pid=109442) INFO 01-20 10:43:38 [loggers.py:127] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 10.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
@@ -2,14 +2,36 @@

## Generative Models

| Model | Support | INT8(W8A8) | AWQ(W4A16) | GPTQ(WNA16) | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Kunlun Graph |
| :------------ | :-----: | :--------: | :--------: | :---------: | :---: | :-------------: | :-------------: | :-----------: | :----------: |
| Qwen2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen2.5 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-Moe | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| Qwen3-Next | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| MiMo-V2-Flash | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| Llama2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Llama3 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Llama3.1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| gpt-oss | ✅ | ✅ | ✅ | ✅ | | ✅ | | | |
| GLM4.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM4.5Air | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM4.7 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| GLM5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| Kimi-K2 | ✅ | - | ✅ | - | | ✅ | | ✅ | ✅ |
| DeepSeek-R1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| DeepSeek-V3 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| DeepSeek-V3.2 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |

## Multimodal Language Models

| Model | Support | INT8(W8A8) | AWQ(W4A16) | GPTQ(WNA16) | LoRA | Tensor Parallel | Expert Parallel | Data Parallel | Kunlun Graph |
| :------------- | :-----: | :--------: | :--------: | :---------: | :---: | :-------------: | :-------------: | :-----------: | :----------: |
| Qwen2-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-VL | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| Qwen3-VL-MoE | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| Qwen3-Omni-MoE | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ |
| InternVL-2.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| InternVL-3.5 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
| InternS1 | ✅ | ✅ | ✅ | ✅ | | ✅ | | ✅ | ✅ |
@@ -1,5 +1,5 @@

unset XPU_DUMMY_EVENT
export XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export XFT_USE_FAST_SWIGLU=1  # use the fast SwiGLU implementation
export XPU_USE_FAST_SWIGLU=1  # use the fast SwiGLU implementation inside the MoE ops
export XMLIR_CUDNN_ENABLED=1
@@ -56,6 +56,16 @@ def register():
    """Register the Kunlun platform"""
    from .utils import redirect_output
    from .vllm_utils_wrapper import direct_register_custom_op, patch_annotations_for_schema

    # Change for GLM5
    if "vllm.transformers_utils.config" in sys.modules:
        from .transformer_utils.config import _XPU_CONFIG_REGISTRY
        sys.modules["vllm.transformers_utils.config"]._CONFIG_REGISTRY = _XPU_CONFIG_REGISTRY

    import vllm.config.model as model_module
    from .config.model import is_deepseek_mla
    model_module.ModelConfig.is_deepseek_mla = property(is_deepseek_mla)

    import_hook()
    return "vllm_kunlun.platforms.kunlun.KunlunPlatform"
0   vllm_kunlun/config/__init__.py   Normal file
22  vllm_kunlun/config/model.py      Normal file
@@ -0,0 +1,22 @@
def is_deepseek_mla(self) -> bool:
    if not hasattr(self.hf_text_config, "model_type"):
        return False
    elif self.hf_text_config.model_type in (
        "deepseek_v2",
        "deepseek_v3",
        "deepseek_v32",
        "deepseek_mtp",
        "kimi_k2",
        "longcat_flash",
        "glm_moe_dsa",
    ):
        return self.hf_text_config.kv_lora_rank is not None
    elif self.hf_text_config.model_type == "eagle":
        # if the model is an EAGLE module, check for the
        # underlying architecture
        return (
            self.hf_text_config.model.model_type
            in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
            and self.hf_text_config.kv_lora_rank is not None
        )
    return False
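The `register()` hunk earlier attaches this function to vllm's `ModelConfig` as a property at import time. A minimal stdlib-only sketch of that monkey-patching pattern (the class here is a stand-in, not vllm's real `ModelConfig`, and the predicate is simplified):

```python
from types import SimpleNamespace


class ModelConfig:  # illustrative stand-in for vllm.config.model.ModelConfig
    def __init__(self, hf_text_config):
        self.hf_text_config = hf_text_config


def is_deepseek_mla(self) -> bool:
    # simplified version of the patched predicate above
    if not hasattr(self.hf_text_config, "model_type"):
        return False
    if self.hf_text_config.model_type == "deepseek_v3":
        return self.hf_text_config.kv_lora_rank is not None
    return False


# attach as a read-only property, the same pattern register() uses
ModelConfig.is_deepseek_mla = property(is_deepseek_mla)

cfg = ModelConfig(SimpleNamespace(model_type="deepseek_v3", kv_lora_rank=512))
print(cfg.is_deepseek_mla)  # accessed as an attribute, not a method call
```

Because `property()` wraps the free function, existing call sites that read `model_config.is_deepseek_mla` keep working while the platform plugin swaps in its own logic.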
@@ -3,91 +3,113 @@ from vllm import ModelRegistry


def register_model():
    # from .demo_model import DemoModel  # noqa: F401
    from .qwen2_vl import Qwen2VLForConditionalGeneration  #noqa: F401
    from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration  #noqa: F401
    from .qwen3_moe import Qwen3MoeForCausalLM  #noqa: F401
    from .qwen3_vl import Qwen3VLForConditionalGeneration
    from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
    from .qwen3_omni_moe_thinker import Qwen3OmniMoeThinkerForConditionalGeneration
    from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration  # noqa: F401
    from .qwen2_vl import Qwen2VLForConditionalGeneration  # noqa: F401
    from .qwen3_moe import Qwen3MoeForCausalLM  # noqa: F401
    from .qwen3_omni_moe_thinker import (  # noqa: F401
        Qwen3OmniMoeThinkerForConditionalGeneration,
    )
    from .qwen3_vl import Qwen3VLForConditionalGeneration  # noqa: F401
    from .qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration  # noqa: F401

    # from .llama4 import Llama4ForCausalLM  #noqa: F401
    # from .mllama4 import Llama4ForConditionalGeneration  #noqa: F401
    # from .deepseek_v2 import KunlunDeepseekV2MoE

    # ModelRegistry.register_model(
    #     "DemoModel",
    #     "vllm_kunlun.model_executor.models.demo_model:DemoModel")

    ModelRegistry.register_model(
        "Qwen2VLForConditionalGeneration",
        "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration")
        "vllm_kunlun.models.qwen2_vl:Qwen2VLForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "Qwen2_5_VLForConditionalGeneration",
        "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration")
        "vllm_kunlun.models.qwen2_5_vl:Qwen2_5_VLForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "Qwen3ForCausalLM",
        "vllm_kunlun.models.qwen3:Qwen3ForCausalLM")
        "Qwen3ForCausalLM", "vllm_kunlun.models.qwen3:Qwen3ForCausalLM"
    )

    ModelRegistry.register_model(
        "Qwen3MoeForCausalLM",
        "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM")
        "Qwen3MoeForCausalLM", "vllm_kunlun.models.qwen3_moe:Qwen3MoeForCausalLM"
    )

    ModelRegistry.register_model(
        "Qwen3NextForCausalLM",
        "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM")
        "Qwen3NextForCausalLM", "vllm_kunlun.models.qwen3_next:Qwen3NextForCausalLM"
    )

    ModelRegistry.register_model(
        "GptOssForCausalLM",
        "vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
        "Qwen3NextMTP", "vllm_kunlun.models.qwen3_next_mtp:Qwen3NextMTP"
    )

    ModelRegistry.register_model(
        "InternLM2ForCausalLM",
        "vllm_kunlun.models.internlm2:InternLM2ForCausalLM")
        "GlmForCausalLM", "vllm_kunlun.models.glm:GlmForCausalLM"
    )

    ModelRegistry.register_model(
        "InternVLChatModel",
        "vllm_kunlun.models.internvl:InternVLChatModel")
        "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
    )

    ModelRegistry.register_model(
        "InternLM2ForCausalLM", "vllm_kunlun.models.internlm2:InternLM2ForCausalLM"
    )

    ModelRegistry.register_model(
        "InternVLChatModel", "vllm_kunlun.models.internvl:InternVLChatModel"
    )

    ModelRegistry.register_model(
        "InternS1ForConditionalGeneration",
        "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration")
        "vllm_kunlun.models.interns1:InternS1ForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "Qwen3VLForConditionalGeneration",
        "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration")
        "vllm_kunlun.models.qwen3_vl:Qwen3VLForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "Qwen3VLMoeForConditionalGeneration",
        "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration")
        "vllm_kunlun.models.qwen3_vl_moe:Qwen3VLMoeForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "Qwen3OmniMoeForConditionalGeneration",
        "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration")
        "vllm_kunlun.models.qwen3_omni_moe_thinker:Qwen3OmniMoeThinkerForConditionalGeneration",
    )

    ModelRegistry.register_model(
        "SeedOssForCausalLM",
        "vllm_kunlun.models.seed_oss:SeedOssForCausalLM")
        "SeedOssForCausalLM", "vllm_kunlun.models.seed_oss:SeedOssForCausalLM"
    )

    ModelRegistry.register_model(
        "MiMoV2FlashForCausalLM",
        "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM")
        "vllm_kunlun.models.mimo_v2_flash:MiMoV2FlashForCausalLM",
    )

    ModelRegistry.register_model(
        "GptOssForCausalLM",
        "vllm_kunlun.models.gpt_oss:GptOssForCausalLM")
        "GptOssForCausalLM", "vllm_kunlun.models.gpt_oss:GptOssForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",
        "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
        "DeepseekV3ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM")
        "DeepseekV32ForCausalLM", "vllm_kunlun.models.deepseek_v2:DeepseekV3ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP")
        "DeepSeekMTPModel", "vllm_kunlun.models.deepseek_mtp:DeepSeekMTP"
    )

    ModelRegistry.register_model(
        "GlmMoeDsaForCausalLM", "vllm_kunlun.models.deepseek_v2:GlmMoeDsaForCausalLM"
    )


def register_quant_method():
    """to do"""
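The registrations above map an architecture name to a `"module.path:ClassName"` string, so the model class is only imported when it is actually needed. A stdlib-only sketch of that lazy string-based registry pattern (`LazyModelRegistry` is an illustrative stand-in, not vllm's real `ModelRegistry` API):

```python
import importlib


class LazyModelRegistry:
    """Maps an architecture name to a 'module.path:ClassName' string
    and imports the class on demand."""

    def __init__(self):
        self._models = {}

    def register_model(self, arch: str, qualname: str) -> None:
        self._models[arch] = qualname

    def resolve(self, arch: str):
        # split "module.path:ClassName" and import lazily
        module_path, class_name = self._models[arch].split(":")
        return getattr(importlib.import_module(module_path), class_name)


registry = LazyModelRegistry()
# use a stdlib class as a stand-in "model" so the example is runnable
registry.register_model("OrderedDictModel", "collections:OrderedDict")
print(registry.resolve("OrderedDictModel"))  # <class 'collections.OrderedDict'>
```

Deferring the import keeps plugin registration cheap: registering dozens of architectures costs only dictionary inserts, and heavy model modules load only for the architecture the server actually serves.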
File diff suppressed because it is too large
303  vllm_kunlun/models/qwen3_next_mtp.py  Normal file
@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""

from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig

from .qwen3_next import Qwen3NextDecoderLayer, Qwen3NextRMSNorm

logger = init_logger(__name__)

KVCache = tuple[torch.Tensor, torch.Tensor]


@support_torch_compile
class Qwen3NextMultiTokenPredictor(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config

        self.config = config
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        self.fc = ColumnParallelLinear(
            self.config.hidden_size * 2,
            self.config.hidden_size,
            gather_output=True,
            bias=False,
            return_bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.fc",
        )

        self.layers = torch.nn.ModuleList(
            Qwen3NextDecoderLayer(
                vllm_config,
                layer_type="full_attention",
                prefix=f"{prefix}.layers.{idx}",
            )
            for idx in range(self.num_mtp_layers)
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm = Qwen3NextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            assert hidden_states.shape[-1] == inputs_embeds.shape[-1]
            inputs_embeds = self.pre_fc_norm_embedding(inputs_embeds)
            hidden_states = self.pre_fc_norm_hidden(hidden_states)
            hidden_states = torch.cat([inputs_embeds, hidden_states], dim=-1)
            hidden_states = self.fc(hidden_states)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        current_step_idx = spec_step_idx % self.num_mtp_layers
        hidden_states, residual = self.layers[current_step_idx](
            positions=positions,
            hidden_states=hidden_states,
            residual=residual,
        )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue

                if "mlp.experts" in name:
                    continue

                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


@support_torch_compile
class Qwen3NextMTP(nn.Module, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["up_proj", "down_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        assert (
            not cache_config.enable_prefix_caching
        ), "Qwen3NextMTP currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.model = Qwen3NextMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "mtp")
        )
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size
        )
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        hidden_states = self.model(
            input_ids, positions, hidden_states, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        shared_weight_names = ["embed_tokens", "lm_head"]

        def remap_weight_names(weights):
            for name, weight in weights:
                if name.startswith("mtp."):
                    name = name.replace("mtp.", "model.")
                elif not any(key in name for key in shared_weight_names):
                    continue
                yield name, weight

        loader = AutoWeightsLoader(self)
        return loader.load_weights(remap_weight_names(weights))
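The `remap_weight_names` generator at the end of `Qwen3NextMTP.load_weights` filters the checkpoint's weight stream: `mtp.*` names are rewritten to `model.*`, shared weights (`embed_tokens`, `lm_head`) pass through unchanged, and everything else is dropped. A stdlib-only sketch of that filtering pattern, with stand-in strings in place of tensors:

```python
def remap_weight_names(weights, shared=("embed_tokens", "lm_head")):
    # mirrors the generator inside Qwen3NextMTP.load_weights
    for name, weight in weights:
        if name.startswith("mtp."):
            name = name.replace("mtp.", "model.")
        elif not any(key in name for key in shared):
            continue  # drop weights that belong only to the base model
        yield name, weight


ckpt = [
    ("mtp.fc.weight", "w0"),                               # renamed
    ("model.embed_tokens.weight", "w1"),                   # shared, kept
    ("model.layers.0.self_attn.qkv_proj.weight", "w2"),    # base-model only, dropped
    ("lm_head.weight", "w3"),                              # shared, kept
]
print(list(remap_weight_names(ckpt)))
```

Because it is a generator, the remapping streams one `(name, tensor)` pair at a time, so the full checkpoint never has to be materialized in memory before `AutoWeightsLoader` consumes it.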
@@ -85,7 +85,7 @@ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
                                              maybe_prefix, merge_multimodal_embeddings)
from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
import xtorch_ops
import kunlun_ops
from einops import repeat

logger = init_logger(__name__)
@@ -16,33 +16,34 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""kunlun custom op entry"""
|
||||
import torch_xmlir
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import os
|
||||
from typing import Optional, List, Dict
|
||||
import vllm.envs as envs
|
||||
import os
|
||||
import ctypes
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import xtorch_ops
|
||||
logger.info(f"Load custom ops library success!")
|
||||
import cocopod # noqa
|
||||
import kunlun_ops
|
||||
|
||||
logger.info("Load custom ops library success!")
|
||||
except ImportError as e:
|
||||
logger.warning("Import error msg: %s", e.msg)
|
||||
|
||||
|
||||
_per_token_smooth_quant = True
|
||||
|
||||
|
||||
def is_per_token_smooth_quant():
|
||||
""" is per token smooth quant """
|
||||
"""is per token smooth quant"""
|
||||
return _per_token_smooth_quant
|
||||
|
||||
|
||||
class KunlunOps:
|
||||
"""KunlunOps"""
|
||||
|
||||
# Attention ops
|
||||
@staticmethod
|
||||
def paged_attention_v1(
|
||||
@@ -67,11 +68,11 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV1 """
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV1"""
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -81,7 +82,7 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128
|
||||
vo_head_dim=128,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -110,11 +111,11 @@ class KunlunOps:
|
||||
blocksparse_vert_stride,
|
||||
blocksparse_block_size,
|
||||
blocksparse_head_sliding_step,
|
||||
alibi_sqrt=False
|
||||
):
|
||||
""" PagedAttentionV2 """
|
||||
alibi_sqrt=False,
|
||||
):
|
||||
"""PagedAttentionV2"""
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -124,31 +125,28 @@ class KunlunOps:
|
||||
is_context=is_context,
|
||||
is_causal=True,
|
||||
out=output,
|
||||
vo_head_dim=128
|
||||
vo_head_dim=128,
|
||||
)
|
||||
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def silu_and_mul(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" silu and mul """
|
||||
xtorch_ops.silu_and_mul(
|
||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor):
|
||||
"""silu and mul"""
|
||||
kunlun_ops.silu_and_mul(
|
||||
x,
|
||||
axis=-1,
|
||||
turn=True,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Activation ops
|
||||
@staticmethod
|
||||
def quick_gelu(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" quick gelu """
|
||||
xtorch_ops.quick_gelu(
|
||||
def quick_gelu(out: torch.Tensor, x: torch.Tensor):
|
||||
"""quick gelu"""
|
||||
kunlun_ops.quick_gelu(
|
||||
x,
|
||||
out=out,
|
||||
)
|
||||
)
|
||||
|
||||
# Layernorm
|
||||
@staticmethod
|
||||
@@ -159,9 +157,7 @@ class KunlunOps:
|
||||
epsilon,
|
||||
):
|
||||
"""rms_norm"""
|
||||
xtorch_ops.rmsnorm(
|
||||
x, weight.to(torch.float32), epsilon, out=out
|
||||
)
|
||||
kunlun_ops.rmsnorm(x, weight.to(torch.float32), epsilon, out=out)
|
||||
|
||||
@staticmethod
|
||||
def fused_add_rms_norm(
|
||||
@@ -172,97 +168,61 @@ class KunlunOps:
|
||||
):
|
||||
"""fused_add_rms_norm"""
|
||||
output = torch.empty_like(x)
|
||||
xtorch_ops.add_rmsnorm(
|
||||
kunlun_ops.add_rmsnorm(
|
||||
x, residual, weight.to(torch.float32), epsilon, out=output
|
||||
)
|
||||
fused_input = x + residual
|
||||
residual.copy_(fused_input, non_blocking=True)
|
||||
x.copy_(output)
|
||||
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
positions, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
query_x = query.contiguous()
|
||||
key_x = key.contiguous()
|
||||
|
||||
num_tokens = query_x.shape[0]
|
||||
num_heads = query_x.shape[1] // head_size
|
||||
num_kv_heads = key_x.shape[1] // head_size
|
||||
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style)
|
||||
|
||||
query_x = query_x.view(num_tokens, num_heads * head_size)
|
||||
key_x = key_x.view(num_tokens, num_kv_heads * head_size)
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, is_neox_style
|
||||
)
|
||||
|
||||
return query_x, key_x
|
||||
|
||||
# Rotary embedding
|
||||
@staticmethod
|
||||
def mrotary_embedding(
|
||||
positions,
|
||||
mrope_section,
|
||||
query,
|
||||
key,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
is_neox_style):
|
||||
positions, mrope_section, query, key, head_size, cos_sin_cache, is_neox_style
|
||||
):
|
||||
"""
|
||||
refactor RotaryEmbedding forward function
|
||||
"""
|
||||
query_x = query.contiguous()
|
||||
key_x = key.contiguous()
|
||||
query_x_dim = query_x.dim()
|
||||
assert is_neox_style
|
||||
xtorch_ops.mrotary_embedding_neox(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
head_size,
|
||||
cos_sin_cache,
|
||||
mrope_section)
|
||||
kunlun_ops.mrotary_embedding_neox(
|
||||
positions, query_x, key_x, head_size, cos_sin_cache, mrope_section
|
||||
)
|
||||
|
||||
query.data = query_x
|
||||
key.data = key_x
|
||||
key.data = key_x
|
||||
return query, key
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping):
|
||||
""" swap_blocks """
|
||||
xtorch_ops.swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping
|
||||
)
|
||||
def swap_blocks(src, dst, block_mapping):
|
||||
"""swap_blocks"""
|
||||
kunlun_ops.swap_blocks(src, dst, block_mapping)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping):
|
||||
""" copy_blocks """
|
||||
def copy_blocks(key_caches, value_caches, block_mapping):
|
||||
"""copy_blocks"""
|
||||
for i in range(len(key_caches)):
|
||||
key_caches[i] = key_caches[i].contiguous()
|
||||
value_caches[i] = value_caches[i].contiguous()
|
||||
xtorch_ops.copy_blocks(
|
||||
kunlun_ops.copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping,
|
||||
@@ -276,16 +236,10 @@ class KunlunOps:
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
):
|
||||
""" reshape_and_cache """
|
||||
):
|
||||
"""reshape_and_cache"""
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping
|
||||
)
|
||||
kunlun_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
@staticmethod
|
||||
def multi_query_kv_attention(
|
||||
@@ -294,7 +248,7 @@ class KunlunOps:
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
**kargs
|
||||
**kargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
@@ -304,18 +258,14 @@ class KunlunOps:
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
output = torch.empty_like(query)
|
||||
alibi_slopes = kargs.get("alibi_slopes", None)
|
||||
mask = kargs.get("mask", None)
|
||||
is_causal = kargs.get("is_causal", True)
|
||||
is_lvsl = kargs.get("is_lvsl", True)
|
||||
|
||||
B, T, Qh, Hd = query.shape
|
||||
KVh = key.size(2)
|
||||
if KVh != Qh:
|
||||
repeat = Qh // KVh
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
value = value.repeat_interleave(repeat, dim=2)
|
||||
xtorch_ops.attention(
|
||||
kunlun_ops.attention(
|
||||
q=query,
|
||||
k_cache=key,
|
||||
v_cache=value,
|
||||
@@ -328,80 +278,90 @@ class KunlunOps:
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def quant_fusedresidual_rmsnorm_op(x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale: bool,
|
||||
type: int = 1):
|
||||
def quant_fusedresidual_rmsnorm_op(
|
||||
x, residual, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
"""Quantized fused residual layer normalization"""
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||
kunlun_ops.quant_fusedresidual_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
out=out,
|
||||
out_scale=out_scale,
|
||||
residual_tensor=residual,
|
||||
)
|
||||
|
||||
if residual is None:
|
||||
return out, out_scale
|
||||
return out, out_scale, residual
|
||||
|
||||
@staticmethod
|
||||
def quant_rmsnorm_op(x,
|
||||
weight,
|
||||
bias,
|
||||
scale_to_int,
|
||||
eps,
|
||||
dyn_scale : bool,
|
||||
type: int = 1):
|
||||
def quant_rmsnorm_op(
|
||||
x, weight, bias, scale_to_int, eps, dyn_scale: bool, type: int = 1
|
||||
):
|
||||
"""Quantized RMSNorm"""
|
||||
|
||||
out = torch.empty_like(x, dtype=torch.int8)
|
||||
if is_per_token_smooth_quant():
|
||||
out_scale = torch.empty(x.shape[:-1], device=x.device, dtype=torch.float).unsqueeze(-1)
|
||||
out_scale = torch.empty(
|
||||
x.shape[:-1], device=x.device, dtype=torch.float
|
||||
).unsqueeze(-1)
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
out=out, out_scale=out_scale)
|
||||
kunlun_ops.quant_rmsnorm(x, weight, bias, eps, out=out, out_scale=out_scale)
|
||||
return out, out_scale
|
||||
|
||||
@staticmethod
def smooth_quant_matmul_column_row_kernels(input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype):
def smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
otype,
):
"""smooth_quant_matmul_column_row_kernels"""
input_shape = input_tensor.shape
weight_shape = weight.shape
if input_tensor.dim() == 3:
input_tensor = input_tensor.reshape(-1, input_shape[-1])
out = torch.empty((input_shape[0] * input_shape[1],
weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0] * input_shape[1], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [input_shape[0], input_shape[1]]
elif input_tensor.dim() == 2:
out = torch.empty((input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device)
out = torch.empty(
(input_shape[0], weight_shape[0]),
dtype=torch.float16,
device=weight.device,
)
output_bs_shape = [-1]
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
weight, smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out)
kunlun_ops.smooth_quant_matmul_column_row_kernels(
input_tensor,
weight,
smoother,
input_scale,
weight_scale,
perTokenScaling,
perChannelScaling,
out=out,
)

out = out.view(*output_bs_shape, weight_shape[0])

@@ -411,6 +371,7 @@ class KunlunOps:
if torch.is_tensor(x):
return (type(x), x.device, x.dtype, x.shape, x.is_contiguous())
return (type(x), x)

@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
@@ -427,23 +388,24 @@ class KunlunOps:
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""fused_moe"""
global_num_experts, up_gate_size, _ = w1.shape
M, N = hidden_states.shape
hidden_dim = w2.shape[1]
normed_score = torch.empty(M,
moe_top_k,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
moe_top_k,
dtype=torch.int32,
device=hidden_states.device)
normed_score = torch.empty(
M, moe_top_k, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M, moe_top_k, dtype=torch.int32, device=hidden_states.device
)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
num_blocks,
global_num_experts,
dtype=torch.int32,
device=hidden_states.device,
)
router_logits = router_logits.to(torch.float)
if scoring_func == "softmax":
@@ -452,24 +414,27 @@ class KunlunOps:
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True)
stable=False,
)
elif scoring_func == "sigmoid":
torch.ops._C.moe_sigmoid_group_topk_norm(
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)
x=router_logits,
topk_index=topk_ids,
norm_score=normed_score,
block_static=block_statistic,
bias=e_score_correction_bias,
scale=1.0,
n_group=num_expert_group,
topk_group=topk_group,
)

if w1_bias is not None or w2_bias is not None:
if w1_bias is not None or w2_bias is not None:
# Right now this branch is for gpt oss
# TODO (@xyDong23): faster here using moe_fc kernel
normed_score = normed_score.to(hidden_states.dtype)
out = torch.zeros(M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device)
out = torch.zeros(
M * moe_top_k, N, dtype=hidden_states.dtype, device=hidden_states.device
)
repeat_x = hidden_states.repeat_interleave(moe_top_k, dim=0)
topk_ids_flat = topk_ids.flatten()
for i in range(global_num_experts):
@@ -477,9 +442,13 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
groupgemm1 = cur_token @ w1[i].T
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
groupgemm1 = cur_token @ w1[i].T
# Add w13 bias
if w1_bias is not None:
groupgemm1 = groupgemm1 + w1_bias[i]
@@ -489,53 +458,129 @@ class KunlunOps:
if w2_bias is not None:
groupgemm2 = groupgemm2 + w2_bias[i]
out[selected_token] = groupgemm2
output = (out.view(M, moe_top_k, N) * normed_score.unsqueeze(2)).sum(dim=1).to(hidden_states.dtype)
output = (
(out.view(M, moe_top_k, N) * normed_score.unsqueeze(2))
.sum(dim=1)
.to(hidden_states.dtype)
)
return output
else:
moe_expand = torch.empty((M * moe_top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M*top_k, N], float
expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E]
sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1]
sorted_tokens_idx = torch.zeros(M * moe_top_k, dtype=torch.int32, device=hidden_states.device)

torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
# from vllm.forward_context import get_forward_context
# forward_context = get_forward_context()
# attn_metadata: AttentionMetadata = forward_context.attn_metadata
# prefix = "model.layers.0.linear_attn"
# if attn_metadata is not None:
# attn_metadata = attn_metadata[prefix]

torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod)
# if attn_metadata is None or attn_metadata.num_prefills > 0 or :
if M * moe_top_k < 400:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids, global_num_experts, False, False, hidden_states
)
)
experts_num_lod = torch.ops.xspeedgate_ops.moe_active_expert_balance(
topk_ids, global_num_experts, False
)
out = torch.ops.xspeedgate_ops.fused_moe(
hidden_states,
w1,
w2,
normed_score.to(hidden_states.dtype),
sorted_tokens_num_lod,
sorted_tokens_idx,
experts_num_lod,
)
return out.sum(1)

y = torch.empty(M,moe_top_k,
w1.shape[1],
if M * moe_top_k > 768:
moe_expand = torch.empty(
(M * moe_top_k, N),
dtype=hidden_states.dtype,
device=hidden_states.device)
device=hidden_states.device,
) # [M*top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1,
dtype=torch.int32,
device=hidden_states.device,
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
)

torch.ops._C.gen_block_statistic(topk_ids, block_statistic)

torch.ops._C.moe_pre_sorted(
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
else:
sorted_tokens_idx, sorted_tokens_num_lod, moe_expand = (
torch.ops.xspeedgate_ops.moe_pre_small(
topk_ids,
global_num_experts,
index_have_neg=False,
sort_mode=True,
x=hidden_states,
)
)

y = torch.empty(
M,
moe_top_k,
w1.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)

moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)

torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
if M < 1024:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
)

d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)

out1 = out1.reshape(-1, out1.shape[-1])
else:
torch.ops._C.moe_fc(
x=moe_expand,
weight=w1,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_top_k,
y=y,
act="SWISH_GLU",
)

y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])

out = torch.empty(
M,
moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)

d = y.shape[-1] // 2
output_shape = (y.shape[:-1] + (d, ))
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)

out = torch.empty(M,moe_top_k,
w2.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device)

out1 = out1.reshape(-1, out1.shape[-1])

torch.ops._C.moe_fc(
x=out1,
weight=w2,
@@ -545,8 +590,12 @@ class KunlunOps:
y=out,
)

dequant_scale = torch.ones([M, moe_top_k], dtype = torch.float32, device=out.device)
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
dequant_scale = torch.ones(
[M, moe_top_k], dtype=torch.float32, device=out.device
)
output = torch.empty(
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
)
sorted_tokens_idx = sorted_tokens_idx.view(M, moe_top_k)

torch.ops._C.moe_post(
@@ -554,9 +603,9 @@ class KunlunOps:
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output
y=output,
)


return output

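The fallback path in `fused_moe` above routes each token to its top-k experts, runs every expert's up/gate and down projections with a SwiGLU activation in between, and recombines the expert outputs with the routing weights. A minimal pure-PyTorch sketch of that flow (all names here are illustrative, not the kernel API; shapes follow the code above):

```python
import torch
import torch.nn.functional as F

def moe_ref(x, w1, w2, router_logits, top_k):
    # x: [M, N]; w1: [E, 2*D, N]; w2: [E, N, D]; router_logits: [M, E]
    M, N = x.shape
    E = w1.shape[0]
    scores = router_logits.softmax(-1)
    topk_w, topk_ids = scores.topk(top_k, dim=-1)          # [M, top_k]
    repeat_x = x.repeat_interleave(top_k, dim=0)           # [M*top_k, N]
    flat_ids = topk_ids.flatten()
    out = torch.zeros(M * top_k, N, dtype=x.dtype)
    for e in range(E):
        sel = flat_ids == e
        if sel.any():
            up_gate = repeat_x[sel] @ w1[e].T              # [?, 2*D]
            d = up_gate.shape[-1] // 2
            act = F.silu(up_gate[:, :d]) * up_gate[:, d:]  # SwiGLU
            out[sel] = act @ w2[e].T
    # Weight each expert output by its routing score and sum over top_k.
    return (out.view(M, top_k, N) * topk_w.unsqueeze(-1)).sum(1)

y = moe_ref(torch.randn(5, 8), torch.randn(4, 12, 8), torch.randn(4, 8, 6),
            torch.randn(5, 4), top_k=2)
print(y.shape)  # torch.Size([5, 8])
```

The fused kernels replace the Python loop with a sort-by-expert (`moe_pre_sorted`) followed by grouped GEMMs (`moe_fc`) and a weighted gather (`moe_post`).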
@staticmethod
@@ -575,23 +624,23 @@ class KunlunOps:
topk_group: Optional[int] = None,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> torch.Tensor:
x = hidden_states
batch, hidden_size = x.shape
batch, hidden_size = x.shape
num_local_experts, up_gate_size, _ = w13_weight.shape

router_logits = x.to(linear_weights.dtype)@linear_weights.T

topk_weights = torch.empty(batch,
top_k,
dtype=router_logits.dtype,
device=router_logits.device)
topk_ids = torch.empty(batch,
top_k,
dtype=torch.int32,
device=router_logits.device)
block_static = torch.empty(0, dtype=torch.int32,device=router_logits.device)
torch.ops._C.moe_softmax_topk(router_logits, topk_weights, topk_ids, block_static)
router_logits = x.to(linear_weights.dtype) @ linear_weights.T

topk_weights = torch.empty(
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
)
topk_ids = torch.empty(
batch, top_k, dtype=torch.int32, device=router_logits.device
)
block_static = torch.empty(0, dtype=torch.int32, device=router_logits.device)
torch.ops._C.moe_softmax_topk(
router_logits, topk_weights, topk_ids, block_static
)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
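The `moe_softmax_topk` kernel fills `topk_weights`/`topk_ids` from the router logits; the `renormalize` step then makes the kept weights sum to one per token. The same computation in eager PyTorch (a sketch, not the kernel's calling convention):

```python
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, 0.1]])
probs = logits.softmax(dim=-1)
topk_weights, topk_ids = probs.topk(2, dim=-1)
# Renormalize so the selected top-k weights sum to 1 per token.
topk_weights = topk_weights / topk_weights.sum(1, keepdim=True)
print(topk_ids.tolist())          # [[0, 2]]
print(float(topk_weights.sum()))  # ~1.0
```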
@@ -605,11 +654,19 @@ class KunlunOps:
selected_token = topk_ids_flat == experts_id
if selected_token.sum():
cur_token = repeat_x[selected_token]
up_gate = torch.empty(selected_token.sum(), up_gate_size//2,
dtype=cur_token.dtype, device=cur_token.device)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
up_gate = torch.empty(
selected_token.sum(),
up_gate_size // 2,
dtype=cur_token.dtype,
device=cur_token.device,
)
torch.ops._C.silu_and_mul(up_gate, cur_token @ w13_weight[i].T)
out[selected_token] = up_gate @ w2_weight[i].T
output = (out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2)).sum(dim=1).to(x.dtype)
output = (
(out.view(batch, top_k, hidden_size) * topk_weights.unsqueeze(2))
.sum(dim=1)
.to(x.dtype)
)

return output

@@ -645,11 +702,12 @@ class KunlunOps:
prompt_lods_cpu: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
) -> torch.Tensor:
) -> torch.Tensor:
"""mla pa block"""
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
device=hidden_states.device)
xtorch_ops.xft_multi_head_latent_page_attention_block(
output = torch.empty(
hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device
)
kunlun_ops.xft_multi_head_latent_page_attention_block(
hidden_states,
q_lora_rank,
kv_lora_rank,
@@ -686,7 +744,6 @@ class KunlunOps:
)
return output


def fused_gdn_gating(
A_log: torch.Tensor,
a: torch.Tensor,
@@ -695,32 +752,41 @@ class KunlunOps:
threshold: float = 20.0,
) -> torch.Tensor:
"""fused_gdn_gating"""
output = xtorch_ops.fused_gdn_gating(
output = kunlun_ops.fused_gdn_gating(
A_log,
a,
dt_bias,
)
return output


def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
'''
Core operator of Gated DeltaNet in the Qwen3-Next model; it fuses sigmoid_gating and delta_rule_update into one kernel.
1. Sigmoid Gating: gates the input, similar to a GLU (Gated Linear Unit).
2. Delta Rule Update: performs the recurrent update of a parallel state-space model (SSM), combined with a local attention mechanism.
'''
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
h0_source: torch.Tensor,
output_final_state: bool,
use_qk_l2norm_in_kernel: bool,
cu_seqlens: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Core operator of Gated DeltaNet in the Qwen3-Next model; it fuses sigmoid_gating and delta_rule_update into one kernel.
1. Sigmoid Gating: gates the input, similar to a GLU (Gated Linear Unit).
2. Delta Rule Update: performs the recurrent update of a parallel state-space model (SSM), combined with a local attention mechanism.
"""

o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
cu_seqlens)
return (o, final_state)
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
q,
k,
v,
g,
beta,
scale,
h0_source,
output_final_state,
use_qk_l2norm_in_kernel,
cu_seqlens,
)
return (o, final_state)
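`fused_recurrent_gated_delta_rule_fwd` fuses the gating with the delta-rule state update described in the docstring. As a reference, the recurrence can be written step by step in plain PyTorch (single head, one sequence; a sketch of the standard gated delta rule under the shapes used above, not the kernel itself):

```python
import torch

def recurrent_gated_delta_rule_ref(q, k, v, g, beta, scale):
    # q, k: [T, K]; v: [T, V]; g (log decay), beta: [T]; state S: [K, V]
    T, K = q.shape
    V = v.shape[-1]
    S = torch.zeros(K, V)
    outs = []
    for t in range(T):
        S = S * g[t].exp()                        # gated decay of the state
        err = v[t] - k[t] @ S                     # delta-rule prediction error
        S = S + beta[t] * torch.outer(k[t], err)  # rank-1 correction
        outs.append((q[t] * scale) @ S)           # read out with the query
    return torch.stack(outs), S

o, final_state = recurrent_gated_delta_rule_ref(
    torch.randn(6, 4), torch.randn(6, 4), torch.randn(6, 3),
    -torch.rand(6), torch.rand(6), scale=4 ** -0.5)
print(o.shape, final_state.shape)  # torch.Size([6, 3]) torch.Size([4, 3])
```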

@@ -93,7 +93,7 @@ class SiluAndMul(CustomOp):

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""forward_cuda"""
import xtorch_ops
import kunlun_ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
@@ -103,7 +103,7 @@ class SiluAndMul(CustomOp):

def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
"""forward_kunlun"""
import xtorch_ops
import kunlun_ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
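`silu_and_mul` splits the last dimension in half and computes `SiLU(gate) * up`, which is why `output_shape` halves the final dimension. The equivalent eager PyTorch:

```python
import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # First half of the last dim is the gate, second half the up projection.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

x = torch.randn(2, 8)
out = silu_and_mul_ref(x)
print(out.shape)  # torch.Size([2, 4])
```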
@@ -251,14 +251,14 @@ class GeluAndMul(CustomOp):
None.
"""
# from vllm import _custom_ops as ops
import xtorch_ops
import kunlun_ops
# d = x.shape[-1] // 2
# output_shape = (x.shape[:-1] + (d, ))
out = torch.empty_like(x)
if self.approximate == "none":
# ops.gelu_and_mul(out, x)
print(x, x.shape)
xtorch_ops.gelu(x, out)
kunlun_ops.gelu(x, out)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out

@@ -7,7 +7,7 @@ import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
import xtorch_ops
import kunlun_ops

logger = init_logger(__name__)

@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
is_context = False
vo_head_dim = -1

xtorch_ops.paged_attention(out,
kunlun_ops.paged_attention(out,
q,
k_cache, None,
block_table,
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
"""
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:

@@ -3,7 +3,7 @@
from typing import Optional

import torch
import xtorch_ops
import kunlun_ops
from vllm.platforms import current_platform


@@ -16,7 +16,7 @@ def merge_attn_states(
output_lse: Optional[torch.Tensor] = None,
) -> None:

return xtorch_ops.attention_merge_stage(
return kunlun_ops.attention_merge_stage(
prefix_output,
prefix_lse,
suffix_output,

@@ -9,60 +9,196 @@
# ruff: noqa: E501
import warnings
from typing import Optional
import torch.nn.functional as F

import cocopod # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
from .index import prepare_chunk_indices
import xspeedgate_ops
import cocopod


def torch_solve_tril(A: torch.Tensor, cu_seqlens: Optional[torch.LongTensor] = None, output_dtype: torch.dtype = torch.float,):
chunk_size=64
A = -A.transpose(1,2)
def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
# mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)

# A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[:,:,:sequence_length,:].transpose(1,2)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)

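`torch_solve_tril` above inverts a unit-lower-triangular chunk by forward substitution: starting from `-A` (strictly lower triangular), each row folds in the already-computed rows, and adding the identity yields `(I + A)^{-1}`. A small self-contained check of that identity on a single chunk (illustrative only; small entries keep the float32 error tight):

```python
import torch

def invert_unit_lower(A):
    # A: [n, n], strictly lower triangular; returns (I + A)^{-1}.
    n = A.shape[-1]
    B = -A.clone()
    for i in range(1, n):
        row = B[i, :i].clone()
        # Fold in the rows already inverted (forward substitution).
        B[i, :i] = row + (row.unsqueeze(-1) * B[:i, :i]).sum(0)
    return B + torch.eye(n)

torch.manual_seed(0)
A = torch.randn(8, 8).tril(-1) * 0.1   # strictly lower triangular
inv = invert_unit_lower(A)
err = ((torch.eye(8) + A) @ inv - torch.eye(8)).abs().max()
print(float(err) < 1e-4)  # True
```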
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = chunk_scaled_dot_kkt_fwd(k=k,
beta=beta,
g_cumsum=g,
cu_seqlens=cu_seqlens,
output_dtype=q.dtype)

# kernel version
torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
chunk_indices = prepare_chunk_indices(
cu_seqlens, 64) if cu_seqlens is not None else None
def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
Simplest version: assumes equal-length sequences and the same number of key and value heads.
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]

batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])

v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)

k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)

return w, u


def split_by_value(tensor, chunk_size=64):
indices = tensor.tolist()
result = set(indices)  # use a set to avoid duplicates

for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]

# compute the first aligned boundary
# we want start + n*chunk_size, where n is the smallest integer that makes the result greater than start
first_boundary = start + chunk_size

# insert all aligned boundaries within (start, end)
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size

return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
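`split_by_value` pads a `cu_seqlens`-style boundary list with extra chunk-aligned boundaries so no segment between consecutive entries exceeds `chunk_size`. A CPU usage example (the helper is re-stated here, minus the device handling, so the example is self-contained):

```python
import torch

def split_by_value(tensor, chunk_size=64):
    # Keep the original boundaries; inside each (start, end) interval,
    # add start + n*chunk_size until reaching end.
    indices = tensor.tolist()
    result = set(indices)
    for i in range(len(indices) - 1):
        boundary = indices[i] + chunk_size
        while boundary < indices[i + 1]:
            result.add(boundary)
            boundary += chunk_size
    return torch.tensor(sorted(result), dtype=tensor.dtype)

print(split_by_value(torch.tensor([0, 100, 160])).tolist())  # [0, 64, 100, 160]
```

The 100-token first sequence is split at its 64-boundary, while the 60-token second sequence already fits in one chunk.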


def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)

# !
# g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)

# !
# A = chunk_scaled_dot_kkt_fwd(k=k,
# beta=beta,
# g_cumsum=g,
# cu_seqlens=cu_seqlens,
# output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)

# torch version
# if get_tensor_model_parallel_rank() == 0:
# torch.save(A, "A_in")
# torch.save(cu_seqlens, "cu_seqlens")
# A2 = A.clone()
torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)

# !
# torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
# if get_tensor_model_parallel_rank() == 0:
# err = torch.max(torch.abs(A - A2))
# print("err", err)
# if err > 1e-3:
# raise
# A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
# for i in range(len(cu_seqlens)-1):
# A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
# A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = torch_solve_tril(A=A_i, cu_seqlens=torch.tensor([0, cu_seqlens[i+1]-cu_seqlens[i]], device=q.device), output_dtype=k.dtype)

"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]

w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
k=k,
v=v,
@@ -71,17 +207,63 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
"""
w, u = recompute_w_u_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
"""

# i
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")

h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)

# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")

o = torch.ops.xspeedgate_ops.chunk_fwd_o(
q=q,
k=k,
@@ -91,8 +273,19 @@ def chunk_gated_delta_rule_fwd(q: torch.Tensor,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_size=64
chunk_size=64,
)
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
@@ -103,18 +296,20 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type='cuda')
def forward(ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False):
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
@@ -136,17 +331,19 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False):
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
@@ -211,42 +408,85 @@ def chunk_gated_delta_rule(q: torch.Tensor,
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert len(
beta.shape
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

if head_first:
raise DeprecationWarning(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead.",
stacklevel=2)
stacklevel=2,
)
q, k, v, beta, g = map(
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
(q, k, v, beta, g))
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
)
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2)
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing.")
if initial_state is not None and initial_state.shape[0] != len(
cu_seqlens) - 1:
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1]**-0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
use_qk_l2norm_in_kernel)
scale = k.shape[-1] ** -0.5

if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()

o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops

kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
else:
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
o = rearrange(o, "b t h ... -> b h t ...")
|
||||
return o, final_state
|
||||
|
||||
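The `cu_seqlens` validation above assumes the flash-attention-style varlen convention: all sequences are concatenated along the time axis of a single batch element, and `cu_seqlens` holds the cumulative sequence boundaries. A small sketch of that convention (the example lengths are made up):

```python
import torch

# three sequences of lengths 4, 2 and 5 packed into one batch row of length 11
cu_seqlens = torch.tensor([0, 4, 6, 11])

num_seqs = len(cu_seqlens) - 1  # what the initial_state check above compares against
lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
print(num_seqs, lengths)  # 3 [4, 2, 5]
```

This is why the code insists on `q.shape[0] == 1` and `initial_state.shape[0] == len(cu_seqlens) - 1` when `cu_seqlens` is given.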
@@ -12,21 +12,21 @@
from typing import Optional

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics({
    'USE_G': lambda args: args['g'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({
@@ -40,7 +40,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
#     ],
#     key=['H', 'K', 'V', 'BT'],
# )
@triton.jit(do_not_specialize=['T'])
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
@@ -67,10 +67,12 @@ def chunk_fwd_kernel_o(

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
@@ -89,12 +91,15 @@ def chunk_fwd_kernel_o(
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
                                (BT, BK), (1, 0))
        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
                                (BK, BT), (0, 1))
        p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
                                (BK, BV), (1, 0))
        p_q = tl.make_block_ptr(
            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
        )
        p_h = tl.make_block_ptr(
            h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
@@ -109,8 +114,8 @@ def chunk_fwd_kernel_o(

    if USE_G:
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
        b_g = tl.load(p_g, boundary_check=(0, ))
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * tl.exp(b_g)[:, None]
        b_A = b_A * tl.exp(b_g[:, None] - b_g[None, :])

@@ -120,10 +125,12 @@ def chunk_fwd_kernel_o(
    # b_A = tl.where(m_A, b_A, 0)
    b_A = tl.where(o_t[:, None] >= o_t[None, :], b_A, 0)

    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
                            (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
                            (BT, BV), (1, 0))
    p_v = tl.make_block_ptr(
        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_o = tl.make_block_ptr(
        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # to fix mma -> mma layout conversion
@@ -133,48 +140,29 @@ def chunk_fwd_kernel_o(


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,  # cumsum of log decay
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64) -> torch.Tensor:
    B, T, Hg, K, V = *q.shape, v.shape[-1]
    H = v.shape[-2]
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,  # cumsum of log decay
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    _, T, _, _, _ = *q.shape, v.shape[-1]
    if FLA_GDN_FIX_BT:
        BT = 64
    else:
        BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    if scale is None:
        scale = k.shape[-1]**-0.5
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=64,
        BV=32
    o = torch.ops.xspeedgate_ops.chunk_fwd_o(
        q, k, v, h, g, scale, cu_seqlens, chunk_indices, chunk_size
    )
    return o
@@ -9,29 +9,29 @@
# ruff: noqa: E501
from typing import Optional

import kunlun_ops
import torch

import xtorch_ops


class FusedRecurrentFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):

        o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        inplace_final_state: bool = True,
        cu_seqlens: Optional[torch.LongTensor] = None,
        ssm_state_indices: Optional[torch.Tensor] = None,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
@@ -44,7 +44,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
            h0_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            is_h0_transposed=True
            is_h0_transposed=True,
        )
        return o, final_state

@@ -130,9 +130,10 @@ def fused_recurrent_gated_delta_rule(
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
            f"Please flatten variable-length inputs before processing.")
            f"Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1]**-0.5
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
@@ -10,22 +10,21 @@
import os
from typing import Optional

import kunlun_ops
import torch
from vllm.triton_utils import tl, triton

import xtorch_ops


BT_LIST = [8, 16, 32, 64, 128]

USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))


@triton.autotune(configs=[
    triton.Config({}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16, 32]
],
                 key=['D'])
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
@@ -49,11 +48,14 @@ def l2norm_fwd_kernel1(
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(configs=[
    triton.Config({'BT': BT}, num_warps=num_warps)
    for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
],
                 key=['D'])
@triton.autotune(
    configs=[
        triton.Config({"BT": BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
@@ -87,67 +89,9 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


def l2norm_fwd_triton(x: torch.Tensor,
                      eps: float = 1e-6,
                      output_dtype: Optional[torch.dtype] = None):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        MBLOCK = 32
        # M, N = x.shape
        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
            x,
            y,
            eps,
            T,
            D,
            MBLOCK,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta['BT']), )

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T, )](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)


def l2norm_fwd(x: torch.Tensor,
               eps: float = 1e-6,
               output_dtype: Optional[torch.dtype] = None):
def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
    out = torch.empty_like(x)
    xtorch_ops.l2norm(x, out, eps)
    kunlun_ops.l2norm(x, out, eps)
    return out
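The kernels above implement a row-wise L2 normalization. As a reference for what `l2norm_fwd` is expected to produce, here is a plain PyTorch sketch (an assumption based on the `xs * rsqrt` computation visible in `l2norm_fwd_kernel2`, not code from the repo):

```python
import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize each row along the last (feature) dimension to unit L2 norm
    return x * torch.rsqrt(x.square().sum(dim=-1, keepdim=True) + eps)

x = torch.tensor([[3.0, 4.0]])
print(l2norm_ref(x))  # ~[[0.6000, 0.8000]]
```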
@@ -19,20 +19,21 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from vllm.triton_utils import tl, triton

from .utils import input_guard


def rms_norm_ref(x,
                 weight,
                 bias,
                 z=None,
                 eps=1e-6,
                 group_size=None,
                 norm_before_gate=True,
                 upcast=True):
def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    dtype = x.dtype
    weight = weight.float()
    bias = bias.float() if bias is not None else None
@@ -43,12 +44,10 @@ def rms_norm_ref(x,
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
                                                                   weight)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
                              eps)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
@@ -57,10 +56,12 @@ def rms_norm_ref(x,
    return out.to(dtype)


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.heuristics(
    {
        "HAS_BIAS": lambda args: args["B"] is not None,
        "HAS_Z": lambda args: args["Z"] is not None,
    }
)
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
@@ -97,17 +98,17 @@ def layer_norm_fwd_kernel(
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.)
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
@@ -149,46 +150,50 @@ def layer_norm_fwd(
    # weight = weight.reshape(N)
    # print("weight",weight.shape)
    # print("x",x.shape)
    assert weight.shape == (N, )
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N, )
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M, ), dtype=torch.float32,
                       device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    layer_norm_fwd_kernel[grid](x,
                                out,
                                weight,
                                bias,
                                z,
                                mean,
                                rstd,
                                x.stride(0),
                                out.stride(0),
                                z.stride(0) if z is not None else 0,
                                M,
                                group_size,
                                eps,
                                BLOCK_N=BLOCK_N,
                                NORM_BEFORE_GATE=norm_before_gate,
                                IS_RMS_NORM=is_rms_norm,
                                num_warps=num_warps)
    layer_norm_fwd_kernel[grid](
        x,
        out,
        weight,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0) if z is not None else 0,
        M,
        group_size,
        eps,
        BLOCK_N=BLOCK_N,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        num_warps=num_warps,
    )
    return out, mean, rstd


@@ -196,17 +201,18 @@ class LayerNormFn(torch.autograd.Function):

    @input_guard
    @staticmethod
    def forward(ctx,
                x,
                weight,
                bias,
                z=None,
                eps=1e-6,
                group_size=None,
                norm_before_gate=True,
                is_rms_norm=False):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
@@ -223,16 +229,15 @@ class LayerNormFn(torch.autograd.Function):
            weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        # y, mean, rstd = torch.ops.xspeedgate_ops.rms_norm_gated_fwd(x, weight, bias, eps, z, group_size, norm_before_gate, is_rms_norm)
        y = torch.empty_like(x)
        mean, rstd = None, None
        import kunlun_ops

        kunlun_ops.rms_norm_gated(
            x, y, z, weight, eps, group_size, norm_before_gate, is_rms_norm
        )

        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
@@ -242,27 +247,27 @@ class LayerNormFn(torch.autograd.Function):
        return y.reshape(x_shape_og)


def layernorm_fn(x,
                 weight,
                 bias,
                 z=None,
                 eps=1e-6,
                 group_size=None,
                 norm_before_gate=True,
                 is_rms_norm=False):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, is_rms_norm)
def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
    )


def rmsnorm_fn(x,
               weight,
               bias,
               z=None,
               eps=1e-6,
               group_size=None,
               norm_before_gate=True):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
                             norm_before_gate, True)
def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, True
    )


class LayerNormGated(nn.Module):
@@ -294,15 +299,16 @@ class LayerNormGated(nn.Module):
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return layernorm_fn(x,
                            self.weight,
                            self.bias,
                            z=z,
                            group_size=self.group_size,
                            eps=self.eps,
                            norm_before_gate=self.norm_before_gate)
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )


class RMSNormGated(nn.Module):
@@ -332,12 +338,13 @@ class RMSNormGated(nn.Module):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
        """
        return rmsnorm_fn(x,
                          self.weight,
                          self.bias,
                          z=z,
                          eps=self.eps,
                          group_size=self.group_size,
                          norm_before_gate=self.norm_before_gate)
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
@@ -11,7 +11,6 @@
from typing import Optional

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
@@ -28,6 +27,7 @@ RESOLUTION = {
    torch.complex64: 1.3e-6,
}


def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
    assert res.dtype == dtype
    ref = ref.to(dtype)
@@ -35,6 +35,7 @@ def assert_close(res, ref, dtype, equal_nan=False, reduce_dim=1):
    rtol = RESOLUTION[dtype]
    torch.testing.assert_close(res, ref, atol=atol, rtol=rtol, equal_nan=equal_nan)


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
@@ -80,7 +81,6 @@ def recompute_u_fwd_kernel(
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
@@ -110,7 +110,6 @@ def recompute_u_fwd_kernel(
    tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))



@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
@@ -195,53 +194,12 @@ def recompute_w_u_fwd(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    recompute_w_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
        k, v, beta, g_cumsum, A, cu_seqlens, chunk_indices, chunk_size=BT
    )
    return w, u
    return w, u
@@ -15,51 +15,52 @@
# This file is a part of the vllm-ascend project.
#

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers import layernorm
from typing import Optional, Union
import xtorch_ops

import torch
from vllm.model_executor.layers import layernorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm


def vllm_kunlun_forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """forward_cuda"""
    if x.is_contiguous() == False:
        # kunlun does not support uncontiguous input and they do not think it is a bug
        # so we must make it contiguous() manually
        x = x.contiguous()
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """forward_cuda"""
    if not x.is_contiguous():
        # kunlun does not support uncontiguous input and they do not think it is a bug
        # so we must make it contiguous() manually
        x = x.contiguous()
    if self.variance_size_override is not None:
        return self.forward_native(x, residual)

    if residual is not None:
        # residual_output = torch.empty_like(residual)
        torch.ops._C.add_rmsnorm(
            x,
            residual,
            residual_output=residual,
            weight=self.weight.data,
            eps=self.variance_epsilon,
            output=x
        )
        return x, residual
    out = torch.empty_like(x)
    torch.ops._C.rmsnorm(
    if residual is not None:
        # residual_output = torch.empty_like(residual)
        torch.ops._C.add_rmsnorm(
            x,
            self.weight.data,
            out,
            self.variance_epsilon,
            residual,
            residual_output=residual,
            weight=self.weight.data,
            eps=self.variance_epsilon,
            output=x,
        )
        return out
        return x, residual
    out = torch.empty_like(x)
    torch.ops._C.rmsnorm(
        x,
        self.weight.data,
        out,
        self.variance_epsilon,
    )
    return out


RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda


class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
    @staticmethod
    def forward_xpu(
@@ -68,30 +69,42 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if x.is_contiguous() == False:
        if not x.is_contiguous():
            # kunlun does not support uncontiguous input and they do not think it is a bug
            # so we must make it contiguous() manually
            x = x.contiguous()

        if x.dim() == 3:
            x_shape = x.shape
            x = x.view(-1, x.size(-1))
        if residual is not None:
            torch.ops._C.add_rmsnorm(
            out = torch.empty_like(x)
            out_residual = torch.empty_like(residual)
            torch.ops._C.gemma_add_rmsnorm(
                x,
                residual,
                residual_output=residual,
                weight=weight+1,
                residual_output=out_residual,
                weight=weight,
                eps=variance_epsilon,
                output=x
                output=out,
            )
        else:
            out = torch.empty_like(x)
            torch.ops._C.gemma_rmsnorm(
                x,
                weight,
                out,
                variance_epsilon,
            )
            return x, residual

        out = torch.empty_like(x)
        torch.ops._C.rmsnorm(
            x,
            weight+1,
            out,
            variance_epsilon,
        )
        return out
        if x.dim() == 3:
            x = x.view(x_shape)
            if out is not None:
                out = out.view(x_shape)

        if residual is not None:
            return out, out_residual
        else:
            return out

    def forward_cuda(
        self,
@@ -99,16 +112,17 @@ class KunlunGemmaRMSNorm(OriGemmaRMSNorm):
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            self.forward_static = self.forward_xpu # only use in cudagraph
            self.forward_static = self.forward_xpu  # only use in cudagraph
            return self.forward_native(x, residual)

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static, backend="aot_eager")
                self.forward_static, backend="aot_eager"
            )
            self._is_compiled = True
        return self.forward_native(x, residual)


RMSNorm.forward_cuda = vllm_kunlun_forward_cuda
RMSNorm.forward = vllm_kunlun_forward_cuda
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm
layernorm.GemmaRMSNorm = KunlunGemmaRMSNorm

File diff suppressed because it is too large.
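The assignments such as `RMSNorm.forward_cuda = vllm_kunlun_forward_cuda` in the diff above swap in the Kunlun kernels by rebinding methods on the class at import time. A minimal standalone sketch of that monkey-patching pattern (the `Norm` and `patched_forward` names are hypothetical, not from the repo):

```python
class Norm:
    """Stand-in for a framework layer class such as RMSNorm."""

    def forward(self, x):
        return f"reference({x})"


def patched_forward(self, x):
    # A platform plugin can redirect every instance to its own kernel
    # simply by rebinding the class attribute.
    return f"patched({x})"


Norm.forward = patched_forward  # affects all existing and future instances

n = Norm()
print(n.forward(3))  # patched(3)
```

Because the rebinding happens at module import, any code that later constructs or calls the framework class picks up the platform-specific implementation without changes.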
@@ -113,7 +113,7 @@ class KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase):
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # NOTE: xtorch_ops use max as scale
        # NOTE: kunlun_ops use max as scale
        with torch.no_grad():
            layer.w13_weight_scale.mul_(127.0)
            layer.w2_weight_scale.mul_(127.0)
@@ -8,9 +8,6 @@ import vllm.envs as envs
from vllm.logger import init_logger


# fix bfloat16 double size issue
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

logger = init_logger(__name__)


class KunlunPlatform(Platform):
vllm_kunlun/transformer_utils/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
#
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-kunlun project.
#


from . import tokenizer

__all__ = ["tokenizer"]
vllm_kunlun/transformer_utils/config.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from transformers import PretrainedConfig
from vllm.transformers_utils.config import LazyConfigDict, _CONFIG_REGISTRY

_XPU_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
    chatglm="ChatGLMConfig",
    deepseek_vl_v2="DeepseekVLV2Config",
    deepseek_v3="DeepseekV3Config",
    deepseek_v32="DeepseekV3Config",
    glm_moe_dsa="DeepseekV3Config",
    kimi_vl="KimiVLConfig",
    Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
    RefinedWeb="RWConfig",  # For tiiuae/falcon-40b(-instruct)
    RefinedWebModel="RWConfig",  # For tiiuae/falcon-7b(-instruct)
    jais="JAISConfig",
    mlp_speculator="MLPSpeculatorConfig",
    medusa="MedusaConfig",
    midashenglm="MiDashengLMConfig",
    eagle="EAGLEConfig",
    speculators="SpeculatorsConfig",
    nemotron="NemotronConfig",
    olmo3="Olmo3Config",
    ovis="OvisConfig",
    ultravox="UltravoxConfig",
    step3_vl="Step3VLConfig",
    step3_text="Step3TextConfig",
    qwen3_next="Qwen3NextConfig",
)
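The registry above maps a Hugging Face `model_type` string to a config class name that `LazyConfigDict` resolves only when the entry is first looked up. As an illustration only, here is a sketch of that lazy-resolution pattern with a hypothetical `LazyDict` (not vLLM's actual `LazyConfigDict`, which resolves class names against the `transformers` package):

```python
class LazyDict(dict):
    """Maps keys to zero-argument factories; resolves and caches on first access.

    Sketch of the lazy-registry pattern only; vLLM's LazyConfigDict differs
    in that it resolves class *names* to classes in `transformers`.
    """

    def __getitem__(self, key):
        value = super().__getitem__(key)
        if callable(value):
            # Unresolved entry: call the factory once and cache the result.
            value = value()
            super().__setitem__(key, value)
        return value


# Hypothetical entry; real entries would resolve to config classes.
registry = LazyDict(chatglm=lambda: "ChatGLMConfig-class")
```

The first `registry["chatglm"]` lookup runs the factory; later lookups return the cached value, so unused entries never pay their import cost.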
vllm_kunlun/transformer_utils/tokenizer.py (new file, 223 lines)
@@ -0,0 +1,223 @@
#
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import tempfile
import shutil
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

import huggingface_hub
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils import tokenizer
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file

if TYPE_CHECKING:
    from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
    TokenizerBase = Any

logger = init_logger(__name__)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, TokenizerBase]


def kunlun_get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> AnyTokenizer:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # avoid circuit import
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            # Use file lock to prevent multiple processes from
            # downloading the same file at the same time.
            with get_lock(tokenizer_name, download_dir):
                tokenizer_path = snapshot_download(
                    model_id=tokenizer_name,
                    cache_dir=download_dir,
                    revision=revision,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    # Ignore weights - we only need the tokenizer.
                    ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
                )
                tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    is_gguf = check_gguf_file(tokenizer_name)
    if is_gguf:
        kwargs["gguf_file"] = Path(tokenizer_name).name
        tokenizer_name = Path(tokenizer_name).parent

    # if tokenizer is from official mistral org
    is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai"
    if is_from_mistral_org and tokenizer_mode != "mistral":
        warnings.warn(
            "It is strongly recommended to run mistral models with "
            '`--tokenizer-mode "mistral"` to ensure correct '
            "encoding and decoding.",
            FutureWarning,
            stacklevel=2,
        )

    tokenizer: AnyTokenizer
    if tokenizer_mode == "mistral":
        tokenizer = MistralTokenizer.from_pretrained(
            str(tokenizer_name), revision=revision
        )
    elif tokenizer_mode == "custom":
        from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

        tokenizer = TokenizerRegistry.get_tokenizer(
            str(tokenizer_name),
            *args,
            revision=revision,
            download_dir=download_dir,
            **kwargs,
        )
    else:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path=tokenizer_name,
                *args,
                trust_remote_code=trust_remote_code,
                revision=revision,
                **kwargs,
            )
        except ValueError as e:
            # If the error pertains to the tokenizer class not existing or not
            # currently being imported,
            # suggest using the --trust-remote-code flag.
            if not trust_remote_code and (
                "does not exist or is not currently imported." in str(e)
                or "requires you to execute the tokenizer file" in str(e)
            ):
                err_msg = (
                    "Failed to load the tokenizer. If the tokenizer "
                    "is a custom tokenizer not yet available in the "
                    "HuggingFace transformers library, consider "
                    "setting `trust_remote_code=True` in LLM or using "
                    "the `--trust-remote-code` flag in the CLI."
                )
                raise RuntimeError(err_msg) from e

            # FIXME: Temporary compatibility code for new config format. Remove after vLLM upgrade.
            if "TokenizersBackend" in str(e):
                logger.warning(
                    "TokenizerBackend not supported, patching tokenizer_config.json "
                    "and loading with PreTrainedTokenizerFast."
                )
                tmp_dir = tempfile.mkdtemp(prefix="vllm_tokenizer_patch_")
                try:
                    TOKENIZER_FILES = [
                        "tokenizer.json",
                        "tokenizer_config.json",
                        "special_tokens_map.json",
                        "added_tokens.json",
                        "chat_template.jinja",
                        "generation_config.json",
                    ]

                    for fname in TOKENIZER_FILES:
                        src = os.path.join(tokenizer_name, fname)
                        if os.path.exists(src):
                            shutil.copy(src, tmp_dir)

                    config_path = os.path.join(tmp_dir, "tokenizer_config.json")
                    with open(config_path, "r", encoding="utf-8") as f:
                        cfg = json.load(f)
                    if cfg.get("tokenizer_class") in ("TokenizersBackend",):
                        cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
                    if "extra_special_tokens" in cfg:
                        cfg["additional_special_tokens"] = cfg.pop(
                            "extra_special_tokens"
                        )

                    with open(config_path, "w", encoding="utf-8") as f:
                        json.dump(cfg, f, indent=2)

                    tokenizer = AutoTokenizer.from_pretrained(
                        tmp_dir,
                        trust_remote_code=trust_remote_code,
                        revision=revision,
                        **kwargs,
                    )
                finally:
                    shutil.rmtree(tmp_dir, ignore_errors=True)

            else:
                raise e

    # The special_tokens in tokenizer should also be
    # controlled by do_lower_case in encoder_config
    encoder_config = get_sentence_transformer_tokenizer_config(
        tokenizer_name, revision
    )
    if isinstance(encoder_config, dict) and encoder_config.get(
        "do_lower_case", False
    ):
        special_tokens_map = {
            k: v.lower() for k, v in tokenizer.special_tokens_map.items()
        }
        tokenizer.add_special_tokens(special_tokens_map)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    tokenizer = get_cached_tokenizer(tokenizer)

    return tokenizer


tokenizer.get_tokenizer = kunlun_get_tokenizer

logger.info_once(
    "[Monkey Patch Applied] >>> vllm.transformer_utils.tokenizer.get_tokenizer \
    --> vllm_kunlun.transformer_utils.tokenizer.kunlun_get_tokenizer"
)
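The tokenizer module above ends by rebinding `tokenizer.get_tokenizer`, so any later code that looks the function up through the vLLM module gets the Kunlun version. A minimal sketch of that monkey-patch pattern, using a toy module and toy functions (not the real vLLM ones):

```python
import types

# Toy module standing in for vllm.transformers_utils.tokenizer.
mod = types.ModuleType("toy_tokenizer")


def get_tokenizer(name):
    # Stand-in for the upstream loader.
    return f"upstream:{name}"


mod.get_tokenizer = get_tokenizer


def kunlun_get_tokenizer(name):
    # Stand-in for the platform-specific replacement.
    return f"kunlun:{name}"


# The patch: rebind the module attribute, as the file above does with
# `tokenizer.get_tokenizer = kunlun_get_tokenizer`.
mod.get_tokenizer = kunlun_get_tokenizer
```

Callers that import the module (rather than the bare function) pick up the replacement automatically; callers that did `from toy_tokenizer import get_tokenizer` before the patch would keep the old binding, which is why such patches must run at package import time.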
vllm_kunlun/v1/attention/backends/gdn_attn.py (new file, 390 lines)
@@ -0,0 +1,390 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention."""

from dataclasses import dataclass
from typing import Optional

import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends import gdn_attn
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    compute_causal_conv1d_metadata,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class GDNAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        return GDNAttentionMetadataBuilder


@dataclass
class GDNAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    has_initial_state: Optional[torch.Tensor] = None
    has_initial_state_cpu: Optional[torch.Tensor] = None

    spec_query_start_loc: Optional[torch.Tensor] = (
        None  # shape: [num_spec_decodes + 1,]
    )
    non_spec_query_start_loc: Optional[torch.Tensor] = (
        None  # shape: [batch - num_spec_decodes + 1,]
    )

    spec_state_indices_tensor: Optional[torch.Tensor] = None  # shape: [batch, num_spec]
    non_spec_state_indices_tensor: Optional[torch.Tensor] = (
        None  # shape: [batch - num_spec_decodes,]
    )
    non_spec_state_indices_tensor_cpu: Optional[torch.Tensor] = None
    spec_sequence_masks: Optional[torch.Tensor] = None  # shape: [batch,]
    spec_token_masks: Optional[torch.Tensor] = (
        None  # shape: [num_prefill_tokens + num_decode_tokens,]
    )
    num_accepted_tokens: Optional[torch.Tensor] = None  # shape: [batch,]

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None


class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):

    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.speculative_config = vllm_config.speculative_config
        self.kv_cache_spec = kv_cache_spec
        if self.speculative_config:
            self.num_spec = self.speculative_config.num_speculative_tokens  # noqa: E501
        else:
            self.num_spec = 0
        self.use_spec_decode = self.num_spec > 0
        self._init_reorder_batch_threshold(1, self.use_spec_decode)

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
            self.compilation_config.max_capture_size,
        )

        self.spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, self.num_spec + 1),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
        self.spec_sequence_masks = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.bool,
            device=device,
        )
        self.spec_token_masks = torch.empty(
            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
            dtype=torch.bool,
            device=device,
        )
        self.spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.non_spec_query_start_loc = torch.empty(
            (self.decode_cudagraph_max_bs + 1,),
            dtype=torch.int32,
            device=device,
        )
        self.num_accepted_tokens = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )

    def build(  # type: ignore[override]
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        num_accepted_tokens: Optional[torch.Tensor] = None,
        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
    ) -> GDNAttentionMetadata:
        m = common_attn_metadata

        query_start_loc = m.query_start_loc
        context_lens = m.num_computed_tokens_cpu
        context_lens_tensor = context_lens.to(query_start_loc.device)
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        if (
            not self.use_spec_decode
            or num_decode_draft_tokens_cpu is None
            or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
            .sum()
            .item()
            == 0
        ):
            spec_sequence_masks = None
            num_spec_decodes = 0
        else:
            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
            num_spec_decodes = spec_sequence_masks.sum().item()
            if num_spec_decodes == 0:
                spec_sequence_masks = None
            else:
                spec_sequence_masks = spec_sequence_masks.to(
                    query_start_loc.device, non_blocking=True
                )

        if spec_sequence_masks is None:
            num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                split_decodes_and_prefills(m, decode_threshold=1)
            )
            num_spec_decode_tokens = 0
            spec_token_masks = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            num_accepted_tokens = None
        else:
            query_lens = query_start_loc[1:] - query_start_loc[:-1]

            non_spec_query_lens = query_lens[~spec_sequence_masks]
            num_decodes = (non_spec_query_lens == 1).sum().item()
            num_prefills = non_spec_query_lens.size(0) - num_decodes
            num_decode_tokens = num_decodes
            num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens

            if num_prefills == 0 and num_decodes == 0:
                spec_token_masks = torch.ones(
                    (
                        min(
                            num_spec_decodes * (self.num_spec + 1),
                            query_start_loc[-1].item(),
                        )
                    ),
                    dtype=torch.bool,
                    device=query_start_loc.device,
                )
                spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
                non_spec_state_indices_tensor = None
                spec_query_start_loc = query_start_loc
                non_spec_query_start_loc = None
            else:
                spec_token_masks = torch.repeat_interleave(
                    spec_sequence_masks, query_lens
                )
                spec_state_indices_tensor = m.block_table_tensor[
                    spec_sequence_masks, : self.num_spec + 1
                ]
                non_spec_state_indices_tensor = m.block_table_tensor[
                    ~spec_sequence_masks, 0
                ]

                spec_query_start_loc = torch.zeros(
                    num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
                )
                non_spec_query_start_loc = torch.zeros(
                    query_lens.size(0) - num_spec_decodes + 1,
                    dtype=torch.int32,
                    device=query_start_loc.device,
                )
                torch.cumsum(
                    query_lens[~spec_sequence_masks],
                    dim=0,
                    out=non_spec_query_start_loc[1:],
                )

            num_spec_decode_tokens = (
                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
            )
            assert num_accepted_tokens is not None
            num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

        if num_prefills > 0:
            has_initial_state = context_lens_tensor > 0
            if spec_sequence_masks is not None:
                has_initial_state = has_initial_state[~spec_sequence_masks]
            has_initial_state_cpu = has_initial_state.cpu()
            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(non_spec_query_start_loc)
            )
        else:
            has_initial_state = None
            has_initial_state_cpu = None
        num_actual_tokens = (
            num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
        )

        # prepare tensors for cudagraph
        #
        # With speculative decoding, the xgrammar backend may rollback tokens
        # and causing some sequences has less draft tokens than self.num_spec.
        #
        # In above cases, the max possible batch size for n tokens, can be
        # min(n, cudagraph_max_bs).
        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_decodes == 0
            and num_spec_decodes <= self.decode_cudagraph_max_bs
            and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

            self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                spec_state_indices_tensor, non_blocking=True
            )
            spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
            spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

            self.spec_sequence_masks[:num_spec_decodes].copy_(
                spec_sequence_masks, non_blocking=True
            )
            spec_sequence_masks = self.spec_sequence_masks[:batch_size]
            spec_sequence_masks[num_spec_decodes:].fill_(False)

            assert spec_token_masks is not None
            self.spec_token_masks[: spec_token_masks.size(0)].copy_(
                spec_token_masks, non_blocking=True
            )
            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
            spec_token_masks[spec_token_masks.size(0) :].fill_(False)

            self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                spec_query_start_loc, non_blocking=True
            )
            spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
            spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
            spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

            self.num_accepted_tokens[:num_spec_decodes].copy_(
                num_accepted_tokens, non_blocking=True
            )
            num_accepted_tokens = self.num_accepted_tokens[:batch_size]
            num_accepted_tokens[num_spec_decodes:].fill_(1)

        if (
            self.use_full_cuda_graph
            and num_prefills == 0
            and num_spec_decodes == 0
            and num_decodes <= self.decode_cudagraph_max_bs
        ):
            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
            batch_size = num_actual_tokens

            self.non_spec_state_indices_tensor[:num_decodes].copy_(
                non_spec_state_indices_tensor, non_blocking=True
            )
            non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
                :batch_size
            ]
            non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            self.non_spec_query_start_loc[: num_decodes + 1].copy_(
                non_spec_query_start_loc, non_blocking=True
            )
            non_spec_num_query_tokens = non_spec_query_start_loc[
                -1
            ]  # type: ignore[index]
            non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
            non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

        if num_accepted_tokens is not None:
            num_accepted_tokens = num_accepted_tokens.to(torch.int32)
        attn_metadata = GDNAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_spec_decodes=num_spec_decodes,
            num_spec_decode_tokens=num_spec_decode_tokens,
            num_actual_tokens=num_actual_tokens,
            has_initial_state=has_initial_state,
            has_initial_state_cpu=has_initial_state_cpu,
            spec_query_start_loc=spec_query_start_loc,
            non_spec_query_start_loc=non_spec_query_start_loc,
            spec_state_indices_tensor=spec_state_indices_tensor,
            non_spec_state_indices_tensor=non_spec_state_indices_tensor,
            non_spec_state_indices_tensor_cpu=(
                non_spec_state_indices_tensor.cpu()
                if non_spec_state_indices_tensor is not None
                else None
            ),
            spec_sequence_masks=spec_sequence_masks,
            spec_token_masks=spec_token_masks,
            num_accepted_tokens=num_accepted_tokens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert (
            m.num_reqs <= self.decode_cudagraph_max_bs
            and m.num_actual_tokens <= self.decode_cudagraph_max_bs
        ), (
            f"GDN only supports decode-only full CUDAGraph capture. "
            f"Make sure batch size ({m.num_reqs}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
            f"and number of tokens ({m.num_actual_tokens}) <= "
            f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
        )

        num_accepted_tokens = torch.diff(m.query_start_loc)
        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

        return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)


gdn_attn.GDNAttentionMetadata = GDNAttentionMetadata
gdn_attn.GDNAttentionMetadataBuilder = GDNAttentionMetadataBuilder
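In `GDNAttentionMetadataBuilder.build` above, speculative and non-speculative requests are separated with a boolean mask and each group's `query_start_loc` is rebuilt as a prefix sum over that group's query lengths. The same bookkeeping in a pure-Python sketch (lists instead of tensors, standing in for the boolean indexing and `torch.cumsum` calls):

```python
from itertools import accumulate


def split_query_start_loc(query_lens, spec_mask):
    """Split per-request query lengths into spec / non-spec groups and
    rebuild each group's query_start_loc prefix sum.

    Pure-Python sketch of the mask + cumsum logic in build(); the real code
    operates on torch tensors with boolean indexing.
    """
    spec = [n for n, is_spec in zip(query_lens, spec_mask) if is_spec]
    non_spec = [n for n, is_spec in zip(query_lens, spec_mask) if not is_spec]
    # query_start_loc[i] = total tokens of the first i requests in the group.
    return [0] + list(accumulate(spec)), [0] + list(accumulate(non_spec))
```

For example, query lengths `[3, 1, 2, 1]` with spec mask `[True, False, True, False]` yield spec `query_start_loc = [0, 3, 5]` and non-spec `query_start_loc = [0, 1, 2]`, matching how each kernel then slices its own flat token buffer.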
@@ -28,9 +28,9 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import kunlun_ops
|
||||
import numpy as np
|
||||
import torch
|
||||
import xtorch_ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
CommonAttentionMetadata,
|
||||
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
|
||||
def __post_init__(self):
|
||||
"""__post_init__"""
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||
|
||||
@property
|
||||
def is_all_encoder_attn_metadata_set(self):
|
||||
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
|
||||
"""build"""
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
|
||||
common_prefix_len = common_prefix_len
|
||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
||||
self.device, non_blocking=True
|
||||
@@ -770,28 +770,17 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
value = value.contiguous()
|
||||
if key_cache.is_contiguous():
|
||||
xtorch_ops.reshape_and_cache(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
)
|
||||
else:
|
||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
cast_key_cache,
|
||||
cast_value_cache,
|
||||
updated_slot_mapping,
|
||||
)
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
key[: attn_metadata.num_actual_tokens],
|
||||
value[: attn_metadata.num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
BLHD_LAYOUT=False,
|
||||
)
|
||||
|
||||
assert attn_type == AttentionType.DECODER
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
# Only enforce this shape-constraint for decoder
|
||||
# self-attention
|
||||
@@ -811,7 +800,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
|
||||
# Prefix cache
|
||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||
v=value_cache,
|
||||
@@ -827,7 +816,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
softmax_lse=None,
|
||||
)
|
||||
else:
|
||||
xtorch_ops.prefill_attention(
|
||||
kunlun_ops.prefill_attention(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
@@ -860,9 +849,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
decode_meta.block_tables * 2
|
||||
) # only test in Qwen3-Next
|
||||
|
||||
sig = inspect.signature(xtorch_ops.speculative_attention)
|
||||
sig = inspect.signature(kunlun_ops.speculative_attention)
|
||||
if "max_window_size" in sig.parameters:
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=output[:num_decode_tokens],
|
||||
# Only MLA support q len > 1 right now
|
||||
q=decode_query.unsqueeze(0),
|
||||
@@ -890,7 +879,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
),
|
||||
)
|
||||
elif not attn_metadata.is_speculative:
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=decode_query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -910,7 +899,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
||||
out = output[:num_decode_tokens]
|
||||
assert out.is_contiguous()
|
||||
|
||||
xtorch_ops.speculative_attention(
|
||||
kunlun_ops.speculative_attention(
|
||||
out=out.view(batch_size, qlen, head_num, self.head_size),
|
||||
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
||||
k_cache=key_cache,
|
||||
|
||||
@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               infer_global_hyperparameters,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-import xtorch_ops
+import kunlun_ops

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
         ) * q_len
         sorted_tokens_idx = torch.arange(
             self.num_heads * q_len, dtype=torch.int, device="cuda")
-        xtorch_ops.mla_bmm_I8(
+        kunlun_ops.mla_bmm_I8(
             x.contiguous(),  # [1, 16, 512] torch.float16
             self.W_UV,  # [16, 128, 512] torch.int8
             self.W_UV_SCALE,  # [2048, 1] torch.float32
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         tp_q_head_num=q.size(1)
         softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
         softmax_lse.fill_(float('-inf'))
-        xtorch_ops.attention(
+        kunlun_ops.attention(
             q=q,
             k_cache=k,
             v_cache=maybe_padded_v,
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             self.W_UK_T = W_UK.transpose(1, 2).contiguous()
             self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
                                           dtype=torch.float, device=kv_b_proj_weight.device)
-            xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
+            kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
             self.W_UV = W_UV.contiguous()
             self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
         else:
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
-            xtorch_ops.concat_and_cache_mla(
+            kunlun_ops.concat_and_cache_mla(
                 k_c_normed,
                 k_pe.squeeze(1),
                 attn_metadata.slot_mapping.flatten(),
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         sorted_tokens_idx = torch.arange(
             self.num_heads * q_len, dtype=torch.int, device="cuda")
         extra_params = {"trans": False}
-        xtorch_ops.mla_bmm_I8(
+        kunlun_ops.mla_bmm_I8(
             decode_q_nope.contiguous(),
             self.W_UK_T,
             self.W_UK_SCALE,
@@ -10,7 +10,7 @@ from packaging import version
 from vllm import envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-import xtorch_ops
+import kunlun_ops
 import os

 logger = init_logger(__name__)
@@ -200,16 +200,16 @@ def flashinfer_sample(
     probs = logits.softmax(dim=-1, dtype=torch.float32)
     if k is None:
         # Top-p only.
-        next_token_ids = xtorch_ops.top_p_sampling_from_probs(
+        next_token_ids = kunlun_ops.top_p_sampling_from_probs(
             probs,top_p=p, deterministic=True)
     elif p is None:
         # Top-k only.
-        next_token_ids = xtorch_ops.top_k_sampling_from_probs(
+        next_token_ids = kunlun_ops.top_k_sampling_from_probs(
             probs, top_k=k, deterministic=True)
     else:
         # Both top-k and top-p.
         k = k.to(torch.int32)
-        next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
+        next_token_ids = kunlun_ops.top_k_top_p_sampling_from_probs(
             probs, top_k=k, top_p=p, deterministic=True)

     return next_token_ids.view(-1)
@@ -1,12 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
 from typing import Union

 import kunlun_ops
 import torch
 import torch.nn as nn

 from vllm.logger import init_logger

 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -54,7 +53,7 @@ class RejectionSampler(nn.Module):
         bonus_token_ids: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        '''
+        """
         Args:
             metadata:
                 Metadata for spec decoding.
@@ -81,7 +80,7 @@ class RejectionSampler(nn.Module):
         Returns:
             output_token_ids (torch.Tensor):
                 A tensor containing the final output token IDs.
-        '''
+        """
         assert metadata.max_spec_len <= MAX_SPEC_LEN
         # [num_tokens, vocab_size]
         # NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -124,11 +123,11 @@ class RejectionSampler(nn.Module):
     """
     output_token_ids_np = output_token_ids.cpu().numpy()
     # Create mask for valid tokens.
-    valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
-                  (output_token_ids_np < vocab_size))
+    valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
+        output_token_ids_np < vocab_size
+    )
     outputs = [
-        row[valid_mask[i]].tolist()
-        for i, row in enumerate(output_token_ids_np)
+        row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
     ]
     return outputs
@@ -179,25 +178,15 @@ def rejection_sample(
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
-        if min(num_draft_tokens) == 1 and max(
-                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
-            rejection_greedy_sample_spec_len_1_pytorch(
-                output_token_ids,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-            )
-        else:
-            rejection_greedy_sample_pytorch(
-                output_token_ids,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-                num_draft_tokens,
-                max_spec_len,
-                is_greedy,
-            )
+        kunlun_ops.rejection_greedy_sample(
+            output_token_ids,
+            cu_num_draft_tokens,
+            draft_token_ids,
+            target_argmax,
+            bonus_token_ids,
+            is_greedy,
+            max_spec_len,
+        )
         if sampling_metadata.all_greedy:
             return output_token_ids
@@ -222,8 +211,9 @@ def rejection_sample(
         sampling_metadata,
         device,
     )
+    bonus_token_ids = bonus_token_ids.squeeze(1)

-    rejection_random_sample_pytorch(
+    kunlun_ops.rejection_random_sample(
         output_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
@@ -235,8 +225,7 @@ def rejection_sample(
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
-        # num_warps=1,
+        no_draft_probs=draft_probs is None,
     )
     return output_token_ids
@@ -374,7 +363,7 @@ def generate_uniform_probs(
         random values in the range [0, 1).
     """
     uniform_probs = torch.rand(
-        (num_tokens, ),
+        (num_tokens,),
         dtype=torch.float32,
         device=device,
     )
@@ -422,7 +411,7 @@ def sample_recovered_tokens(
             q[i].exponential_(generator=generator)

     recovered_token_ids = torch.empty_like(draft_token_ids)
-    sample_recovered_tokens_pytorch(
+    kunlun_ops.sample_recovered_tokens(
         recovered_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
@@ -430,16 +419,16 @@ def sample_recovered_tokens(
         target_probs,
         q,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        no_draft_probs=draft_probs is None,
     )
     return recovered_token_ids


 def rejection_greedy_sample_spec_len_1_pytorch(
-        output_token_ids,  # [batch_size, 2]
-        draft_token_ids,  # [num_tokens]
-        target_argmax,  # [num_tokens]
-        bonus_token_ids,  # [batch_size]
+    output_token_ids,  # [batch_size, 2]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
 ):
     batch_size = output_token_ids.size(0)
     num_tokens = draft_token_ids.size(0)
@@ -447,73 +436,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
     accept_req_mask = draft_token_ids == target_argmax
     output_token_ids[:, 0] = target_argmax
     bonus_token_ids = bonus_token_ids.squeeze(1)
-    output_token_ids[:, 1] = torch.where(accept_req_mask, bonus_token_ids,
-                                         output_token_ids[:, 1])
+    output_token_ids[:, 1] = torch.where(
+        accept_req_mask, bonus_token_ids, output_token_ids[:, 1]
+    )


 def rejection_greedy_sample_pytorch(
-        output_token_ids,  # [batch_size, max_spec_len + 1]
-        cu_num_draft_tokens,  # [batch_size]
-        draft_token_ids,  # [num_tokens]
-        target_argmax,  # [num_tokens]
-        bonus_token_ids,  # [batch_size]
-        draft_tokens_per_req,  # [batch_size], list
-        max_spec_len,
-        is_greedy=None,  # [batch_size] or None
+    output_token_ids,  # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    target_argmax,  # [num_tokens]
+    bonus_token_ids,  # [batch_size]
+    draft_tokens_per_req,  # [batch_size], list
+    max_spec_len,
+    is_greedy=None,  # [batch_size] or None
 ):
     batch_size = output_token_ids.size(0)
     num_tokens = draft_token_ids.size(0)
     device = output_token_ids.device
     draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
-        device, non_blocking=True)
+        device, non_blocking=True
+    )
     if is_greedy is None:
         is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)

     start_indices = cu_num_draft_tokens - draft_tokens_per_req
     req_ids = torch.arange(batch_size, device=device)
     token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
-    token_positions = torch.arange(
-        num_tokens, device=device) - start_indices[token_req_ids]
+    token_positions = (
+        torch.arange(num_tokens, device=device) - start_indices[token_req_ids]
+    )

     # Find the first mismatch position of each request.
-    mismatch_global = (draft_token_ids != target_argmax)
+    mismatch_global = draft_token_ids != target_argmax
     if max_spec_len == 0:
-        first_mismatch_pos_per_req = torch.zeros(batch_size,
-                                                 dtype=torch.long,
-                                                 device=device)
+        first_mismatch_pos_per_req = torch.zeros(
+            batch_size, dtype=torch.long, device=device
+        )
     else:
         # [bs, max_spec_len]
-        pos_matrix = torch.full((batch_size, max_spec_len),
-                                -1,
-                                dtype=torch.long,
-                                device=device)
+        pos_matrix = torch.full(
+            (batch_size, max_spec_len), -1, dtype=torch.long, device=device
+        )
         pos_matrix[token_req_ids, token_positions] = token_positions
-        mismatch_matrix = torch.full((batch_size, max_spec_len),
-                                     False,
-                                     dtype=torch.bool,
-                                     device=device)
+        mismatch_matrix = torch.full(
+            (batch_size, max_spec_len), False, dtype=torch.bool, device=device
+        )
         mismatch_matrix[token_req_ids, token_positions] = mismatch_global
-        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
-                                         max_spec_len * 2)
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
         first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
-        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
+        no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
         first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
-            no_mismatch_mask]
+            no_mismatch_mask
+        ]

     # Copy matched target tokens into output.
-    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
-                             draft_tokens_per_req)
-    copy_indices = torch.arange(max_spec_len + 1,
-                                device=device).expand(batch_size, -1)
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1, draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1, device=device).expand(batch_size, -1)
     copy_mask = copy_indices < copy_len.unsqueeze(1)
     greedy_mask = is_greedy.unsqueeze(1)
     final_copy_mask = copy_mask & greedy_mask
     global_idx = start_indices.unsqueeze(1) + copy_indices
-    output_token_ids[final_copy_mask] = target_argmax[
-        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    output_token_ids[final_copy_mask] = target_argmax[global_idx[final_copy_mask]].to(
+        output_token_ids.dtype
+    )
     # Fill bonus token.
-    needs_bonus = is_greedy & (first_mismatch_pos_per_req
-                               >= draft_tokens_per_req)
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req >= draft_tokens_per_req)
     if torch.any(needs_bonus):
         bonus_rows = torch.where(needs_bonus)[0]
         bonus_cols = draft_tokens_per_req[bonus_rows]
@@ -556,11 +544,9 @@ def rejection_random_sample_pytorch(
         if IS_NGRAM:
             draft_prob = 1.0
         else:
-            draft_prob = draft_probs[start_idx + pos,
-                                     draft_token_id].item()
+            draft_prob = draft_probs[start_idx + pos, draft_token_id].item()

-        target_prob = target_probs[start_idx + pos,
-                                   draft_token_id].item()
+        target_prob = target_probs[start_idx + pos, draft_token_id].item()
         uniform_prob = uniform_probs[start_idx + pos].item()

         if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
@@ -629,12 +615,11 @@ def sample_recovered_tokens_pytorch(
     else:
         draft_p = draft_probs[token_idx].clone()
         target_p = target_probs[token_idx].clone()
-        prob = torch.maximum(target_p - draft_p,
-                             torch.tensor(0.0, device=target_p.device))
+        prob = torch.maximum(
+            target_p - draft_p, torch.tensor(0.0, device=target_p.device)
+        )

-    q_values = torch.full((vocab_size, ),
-                          float('-inf'),
-                          device=q.device)
+    q_values = torch.full((vocab_size,), float("-inf"), device=q.device)
     q_values[:vocab_size] = q[req_idx, :vocab_size]

     recovered_id = torch.argmax(prob / q_values).item()
@@ -642,4 +627,3 @@ def sample_recovered_tokens_pytorch(

     if IS_NGRAM:
         target_probs[token_idx, draft_token_id] = orig_prob
-
@@ -337,5 +337,5 @@ def prepare_next_token_ids_padded(
     return next_token_ids, valid_sampled_tokens_count


-EagleProposer.propose = propose
+# EagleProposer.propose = propose
 EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
@@ -386,8 +386,8 @@ def silu_and_mul_quant_xpu(
     pass


+import kunlun_ops  # noqa: E402
 import torch  # noqa: E402
-import xtorch_ops  # noqa: E402
 from torch.library import custom_op, impl  # noqa: E402


@@ -405,9 +405,9 @@ def add_rmsnorm(
     residual_output: torch.Tensor = None,
     output_max: torch.Tensor = None,
 ) -> None:
-    xtorch_ops.add_rmsnorm(
+    kunlun_ops.add_rmsnorm(
         x,
-        y,  # originally written as residual; this is actually y
+        y,
         residual_output=residual_output,
         weight=weight,
         eps=eps,
@@ -429,7 +429,7 @@ def add_rmsnorm_cuda(
     residual_output: torch.Tensor = None,
     output_max: torch.Tensor = None,
 ) -> None:
-    xtorch_ops.add_rmsnorm(
+    kunlun_ops.add_rmsnorm(
         x,
         y,
         residual_output=residual_output,
@@ -451,7 +451,7 @@ def rmsnorm(
     residual_output: torch.Tensor = None,
     output_max: torch.Tensor = None,
 ) -> None:
-    xtorch_ops.rmsnorm(
+    kunlun_ops.rmsnorm(
         x,
         weight,
         output,
@@ -471,7 +471,7 @@ def rmsnorm_cuda(
     residual_output: torch.Tensor = None,
     output_max: torch.Tensor = None,
 ) -> None:
-    xtorch_ops.rmsnorm(
+    kunlun_ops.rmsnorm(
         x,
         weight,
         output,
@@ -523,6 +523,145 @@ def _fake_add_rmsnorm(
 add_rmsnorm.register_fake(_fake_add_rmsnorm)


+@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
+def gemma_add_rmsnorm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_add_rmsnorm wrapper")
+    kunlun_ops.gemma_add_rmsnorm(
+        x,
+        y,
+        weight=weight,
+        output=output,
+        eps=eps,
+        enable_pdl=enable_pdl,
+        interweaved=interweaved,
+        store_output_before_norm=store_output_before_norm,
+        bias=bias,
+        smooth=smooth,
+        residual_output=residual_output,
+        force_sdnn=force_sdnn,
+    )
+
+
+@impl("_C::gemma_add_rmsnorm", "CUDA")
+def gemma_add_rmsnorm_cuda(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_add_rmsnorm_cuda wrapper")
+    kunlun_ops.gemma_add_rmsnorm(
+        x,
+        y,
+        weight=weight,
+        output=output,
+        eps=eps,
+        enable_pdl=enable_pdl,
+        interweaved=interweaved,
+        store_output_before_norm=store_output_before_norm,
+        bias=bias,
+        smooth=smooth,
+        residual_output=residual_output,
+        force_sdnn=force_sdnn,
+    )
+
+
+def _fake_gemma_add_rmsnorm(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweaved: bool = False,
+    store_output_before_norm: bool = True,
+    bias: torch.Tensor = None,
+    smooth: torch.Tensor = None,
+    residual_output: torch.Tensor = None,
+    force_sdnn: bool = False,
+):
+    output.fake_shape = x.shape
+    output.fake_dtype = x.dtype
+    return None
+
+
+gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
+
+
+@custom_op("_C::gemma_rmsnorm", mutates_args=())
+def gemma_rmsnorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweave: bool = False,
+    bias: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_rmsnorm wrapper")
+    kunlun_ops.gemma_rmsnorm(
+        x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
+    )
+
+
+@impl("_C::gemma_rmsnorm", "CUDA")
+def gemma_rmsnorm_cuda(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweave: bool = False,
+    bias: torch.Tensor = None,
+    force_sdnn: bool = False,
+) -> None:
+    # print("gemma_rmsnorm_cuda wrapper")
+    kunlun_ops.gemma_rmsnorm(
+        x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
+    )
+
+
+def _fake_gemma_rmsnorm(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    output: torch.Tensor,
+    eps: float = 1e-5,
+    enable_pdl: bool = False,
+    interweave: bool = False,
+    bias: torch.Tensor = None,
+    force_sdnn: bool = False,
+):
+    # set shape/dtype, but do not return a value
+    output.fake_shape = x.shape
+    output.fake_dtype = x.dtype
+    return None
+
+
+gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
+
+
 @custom_op("_C::split_norm_rope_neox", mutates_args=())
 def split_norm_rope_neox(
     q_emb: torch.Tensor,
@@ -541,7 +680,7 @@ def split_norm_rope_neox(
     rotary_dim: int,
     emb_batch_size: int = 1,
 ) -> None:
-    xtorch_ops.split_norm_rope_neox(
+    kunlun_ops.split_norm_rope_neox(
         q_emb,
         k_emb,
         v_out,
@@ -577,7 +716,7 @@ def split_norm_rope_neox_cuda(
     rotary_dim: int,
     emb_batch_size: int = 1,
 ) -> None:
-    xtorch_ops.split_norm_rope_neox(
+    kunlun_ops.split_norm_rope_neox(
         q_emb,
         k_emb,
         v_out,
@@ -649,7 +788,7 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"):
 def silu_and_mul(
     out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
 ) -> None:
-    xtorch_ops.swiglu(
+    kunlun_ops.swiglu(
         x=x,
         y=out,
     )
@@ -659,7 +798,7 @@ def silu_and_mul(
 def silu_and_mul_cuda(
     out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
 ) -> None:
-    xtorch_ops.swiglu(
+    kunlun_ops.swiglu(
         x=x,
         y=out,
     )
@@ -736,7 +875,7 @@ def moe_softmax_topk(
     axis: int = -1,
     turn: bool = True,
 ) -> None:
-    xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
+    kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)


 @impl("_C::moe_softmax_topk", "CUDA")
@@ -748,7 +887,7 @@ def moe_softmax_topk_cuda(
     axis: int = -1,
     turn: bool = True,
 ) -> None:
-    xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
+    kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)


 def _fake_moe_softmax_topk(
@@ -781,7 +920,7 @@ def moe_ffn_block(
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
 ) -> None:
-    xtorch_ops.moe_ffn_block(
+    kunlun_ops.moe_ffn_block(
         x=x,
         gate_w=gate_w,
         inter_w=inter_w,
@@ -812,7 +951,7 @@ def moe_ffn_block_cuda(
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
 ) -> None:
-    xtorch_ops.moe_ffn_block(
+    kunlun_ops.moe_ffn_block(
         x=x,
         gate_w=gate_w,
         inter_w=inter_w,
@@ -863,7 +1002,7 @@ def moe_ffn_per_token_block(
     ep_size: int = 1,
     ep_rank: int = 0,
 ) -> None:
-    xtorch_ops.moe_ffn_per_token_block(
+    kunlun_ops.moe_ffn_per_token_block(
         x=x,
         inter_weight=inter_weight,
         inter_scale=inter_scale,
@@ -897,7 +1036,7 @@ def moe_ffn_per_token_block_cuda(
     ep_size: int = 1,
     ep_rank: int = 0,
 ) -> None:
-    xtorch_ops.moe_ffn_per_token_block(
+    kunlun_ops.moe_ffn_per_token_block(
         x=x,
         inter_weight=inter_weight,
         inter_scale=inter_scale,
@@ -948,7 +1087,7 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    xtorch_ops.rotary_embedding(
+    kunlun_ops.rotary_embedding(
         positions=positions,
         query=query,
         key=key,
@@ -967,7 +1106,7 @@ def rotary_embedding_cuda(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    xtorch_ops.rotary_embedding(
+    kunlun_ops.rotary_embedding(
         positions=positions,
         query=query,
         key=key,
@@ -999,7 +1138,7 @@ def gemm_I8_I8_bf16_nt(
     weight_scale: torch.Tensor,
     out: torch.Tensor,
 ) -> None:
-    xtorch_ops.gemm_I8_I8_bf16_nt(
+    kunlun_ops.gemm_I8_I8_bf16_nt(
         lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
     )

@@ -1012,7 +1151,7 @@ def gemm_I8_I8_bf16_nt_cuda(
     weight_scale: torch.Tensor,
     out: torch.Tensor,
 ) -> None:
-    xtorch_ops.gemm_I8_I8_bf16_nt(
+    kunlun_ops.gemm_I8_I8_bf16_nt(
         lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
     )

@@ -1038,7 +1177,7 @@ def moe_softmax_topk_norm(
     block_statistic: torch.Tensor,
     stable: bool = True,
 ) -> None:
-    xtorch_ops.moe_softmax_topk_norm(
+    kunlun_ops.moe_softmax_topk_norm(
         x, normed_score, topk_index, block_statistic, stable
     )

@@ -1051,7 +1190,7 @@ def moe_softmax_topk_norm_cuda(
     block_statistic: torch.Tensor,
     stable: bool = True,
 ) -> None:
-    xtorch_ops.moe_softmax_topk_norm(
+    kunlun_ops.moe_softmax_topk_norm(
         x, normed_score, topk_index, block_statistic, stable
     )

@@ -1071,14 +1210,14 @@ moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)

 @custom_op("_C::gen_block_statistic", mutates_args=())
 def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
-    xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
+    kunlun_ops.gen_block_statistic(topk_ids, block_statistic)


 @impl("_C::gen_block_statistic", "CUDA")
 def gen_block_statistic_cuda(
     topk_ids: torch.Tensor, block_statistic: torch.Tensor
 ) -> None:
-    xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
+    kunlun_ops.gen_block_statistic(topk_ids, block_statistic)


 def fake_gen_block_statistic(
@@ -1101,7 +1240,7 @@ def moe_pre_sorted(
     sorted_tokens_num_lod: torch.Tensor,
     index_have_neg: bool = False,
 ) -> None:
-    xtorch_ops.moe_pre_sorted(
+    kunlun_ops.moe_pre_sorted(
         x,
         topk_index,
         block_statistic,
@@ -1123,7 +1262,7 @@ def moe_pre_sorted_cuda(
     sorted_tokens_num_lod: torch.Tensor,
     index_have_neg: bool = False,
 ) -> None:
-    xtorch_ops.moe_pre_sorted(
+    kunlun_ops.moe_pre_sorted(
         x,
         topk_index,
         block_statistic,
@@ -1171,7 +1310,7 @@ def moe_fc(
     use_pack_int4: Optional[bool] = False,
     sort_mode: Optional[bool] = True,
 ) -> None:
-    xtorch_ops.moe_fc(
+    kunlun_ops.moe_fc(
         x=x,
         weight=weight,
         sorted_tokens_num_lod=sorted_tokens_num_lod,
@@ -1214,7 +1353,7 @@ def moe_fc_cuda(
     use_pack_int4: Optional[bool] = False,
     sort_mode: Optional[bool] = True,
 ) -> None:
-    xtorch_ops.moe_fc(
+    kunlun_ops.moe_fc(
         x=x,
         weight=weight,
         sorted_tokens_num_lod=sorted_tokens_num_lod,
@@ -1270,7 +1409,7 @@ def moe_post(
     dequant_scale: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
+    kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)


 @impl("_C::moe_post", "CUDA")
@@ -1281,7 +1420,7 @@ def moe_post_cuda(
     dequant_scale: torch.Tensor,
     y: torch.Tensor,
 ) -> None:
-    xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
+    kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)


 def fake_moe_post(
@@ -1308,7 +1447,7 @@ def moe_sigmoid_group_topk_norm(
     n_group: int,
     topk_group: int,
 ) -> None:
-    xtorch_ops.moe_sigmoid_group_topk_norm(
+    kunlun_ops.moe_sigmoid_group_topk_norm(
         x=x,
         norm_score=norm_score,
         topk_index=topk_index,
@@ -1331,7 +1470,7 @@ def moe_sigmoid_group_topk_norm_cuda(
     n_group: int,
     topk_group: int,
 ) -> None:
-    xtorch_ops.moe_sigmoid_group_topk_norm(
+    kunlun_ops.moe_sigmoid_group_topk_norm(
         x=x,
         norm_score=norm_score,
         topk_index=topk_index,
@@ -1376,7 +1515,7 @@ def awq_dequantize(
         device=qweight.device,
     )
     group_m = int(qweight.shape[0] / scales.shape[0])
-    xtorch_ops.awq_dequantize(
+    kunlun_ops.awq_dequantize(
         qweight=qweight,
         scales=scales,
         zeros=zeros,
@@ -1402,7 +1541,7 @@ def awq_dequantize_cuda(
         device=qweight.device,
     )
     group_m = int(qweight.shape[0] / scales.shape[0])
-    xtorch_ops.awq_dequantize(
+    kunlun_ops.awq_dequantize(
         qweight=qweight,
         scales=scales,
         zeros=zeros,
@@ -1447,7 +1586,7 @@ def awq_gemm(
         (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
     )
     group_size = int(qweight.shape[0] / scale.shape[0])
-    xtorch_ops.awq_gemm(
+    kunlun_ops.awq_gemm(
         x=x,
         w=qweight,
         scale=scale,
@@ -1471,7 +1610,7 @@ def awq_gemm_cuda(
         (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
     )
     group_size = int(qweight.shape[0] / scale.shape[0])
-    xtorch_ops.awq_gemm(
+    kunlun_ops.awq_gemm(
         x=x,
         w=qweight,
         scale=scale,
@@ -1508,7 +1647,7 @@ def gptq_shuffle(
     q_perm: torch.Tensor,
     bit: int,
 ) -> None:
-    xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
+    kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)


 @impl("_C::gptq_shuffle", "CUDA")
@@ -1517,7 +1656,7 @@ def gptq_shuffle_cuda(
     q_perm: torch.Tensor,
     bit: int,
 ) -> None:
-    xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
+    kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)


 def _fake_gptq_shuffle(
@@ -1541,7 +1680,7 @@ def concat_and_cache_mla(
     kv_cache: torch.Tensor,  # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
     slot_mapping: torch.Tensor,  # [num_tokens] or [num_actual_tokens]
 ) -> None:
-    xtorch_ops.concat_and_cache_mla(
+    kunlun_ops.concat_and_cache_mla(
         kv_c=kv_c,
         k_pe=k_pe,
         slot_mapping=slot_mapping,
@@ -1556,7 +1695,7 @@ def concat_and_cache_mla_cuda(
     kv_cache: torch.Tensor,  # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
     slot_mapping: torch.Tensor,  # [num_tokens] or [num_actual_tokens]
 ) -> None:
-    xtorch_ops.concat_and_cache_mla(
+    kunlun_ops.concat_and_cache_mla(
         kv_c=kv_c,
         k_pe=k_pe,
         slot_mapping=slot_mapping,
@@ -1598,7 +1737,7 @@ def scaled_int8_quant(
     azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
     if symmetric:
         # NOTE: For quant2d ops, scale represents max.
-        xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
+        kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
     else:
         torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
             x_q, x.contiguous(), scale, azp
@@ -1625,7 +1764,7 @@ def scaled_int8_quant_cuda(
     azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
     if symmetric:
         # NOTE: For quant2d ops, scale represents max.
-        xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
+        kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
     else:
         torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
             x_q, x.contiguous(), scale, azp
@@ -1777,7 +1916,7 @@ def matmul(
         dtype=out_dtype,
         device=x.device,
     )
-    xtorch_ops.matmul(
+    kunlun_ops.matmul(
         x=x.contiguous(),
         w=w.contiguous(),
         out=out,
@@ -1814,7 +1953,7 @@ def matmul_cuda(
         dtype=out_dtype,
         device=x.device,
     )
-    xtorch_ops.matmul(
+    kunlun_ops.matmul(
         x=x.contiguous(),
         w=w.contiguous(),
         out=out,
@@ -1865,7 +2004,7 @@ def quant2d(
     max: torch.Tensor,
     force_sdnn: bool = False,
 ) -> None:
-    xtorch_ops.quant2d(
+    kunlun_ops.quant2d(
         x=x,
         y=x_q,
         max=max,
@@ -1880,7 +2019,7 @@ def quant2d_cuda(
     max: torch.Tensor,
     force_sdnn: bool = False,
 ) -> None:
-    xtorch_ops.quant2d(
+    kunlun_ops.quant2d(
         x=x,
         y=x_q,
         max=max,
@@ -1954,7 +2093,7 @@ def I8_mqa_logits(
     is_causal: Optional[bool] = False,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.I8_mqa_logits(
+    kunlun_ops.I8_mqa_logits(
         q=q,
         fused_kv_cache=fused_kv_cache,
         weights=weights,
@@ -1984,7 +2123,7 @@ def I8_mqa_logits_cuda(
     is_causal: Optional[bool] = False,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.I8_mqa_logits(
+    kunlun_ops.I8_mqa_logits(
         q=q,
         fused_kv_cache=fused_kv_cache,
         weights=weights,
@@ -2034,7 +2173,7 @@ def I8_paged_mqa_logits(
     out: torch.Tensor,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.I8_paged_mqa_logits(
+    kunlun_ops.I8_paged_mqa_logits(
         q=q,
         fused_kv_cache=fused_kv_cache,
         weights=weights,
@@ -2060,7 +2199,7 @@ def I8_paged_mqa_logits_cuda(
     out: torch.Tensor,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.I8_paged_mqa_logits(
+    kunlun_ops.I8_paged_mqa_logits(
         q=q,
         fused_kv_cache=fused_kv_cache,
         weights=weights,
@@ -2111,7 +2250,7 @@ def sparse_prefill_fwd_opt(
     is_causal: Optional[bool] = True,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.sparse_prefill_fwd_opt(
+    kunlun_ops.sparse_prefill_fwd_opt(
         q=q,
         kv=kv,
         indices=indices,
@@ -2147,7 +2286,7 @@ def sparse_prefill_fwd_opt_cuda(
     is_causal: Optional[bool] = True,
     use_xfa_boost: Optional[bool] = False,
 ) -> None:
-    xtorch_ops.sparse_prefill_fwd_opt(
+    kunlun_ops.sparse_prefill_fwd_opt(
         q=q,
         kv=kv,
         indices=indices,
@@ -2207,7 +2346,7 @@ def fwd_kvcache_mla(
     use_xfa_boost: Optional[bool] = False,
     kv_lod_xpu: Optional[torch.Tensor] = None,
 ) -> None:
-    xtorch_ops.fwd_kvcache_mla(
+    kunlun_ops.fwd_kvcache_mla(
         q_c=q_c,
         kv_cache=kv_cache,
         indices=indices,
@@ -2241,7 +2380,7 @@ def fwd_kvcache_mla_cuda(
     use_xfa_boost: Optional[bool] = False,
     kv_lod_xpu: Optional[torch.Tensor] = None,
 ) -> None:
-    xtorch_ops.fwd_kvcache_mla(
+    kunlun_ops.fwd_kvcache_mla(
         q_c=q_c,
         kv_cache=kv_cache,
         indices=indices,
@@ -2293,7 +2432,7 @@ def dequant_int4(
     int4_signed: bool = True,
     use_mode_fast: bool = False,
 ) -> None:
-    xtorch_ops.dequant_int4(
+    kunlun_ops.dequant_int4(
         x=x,
         scale=scale,
         zero=zero,
@@ -2315,7 +2454,7 @@ def dequant_int4_cuda(
     int4_signed: bool = True,
     use_mode_fast: bool = False,
 ) -> None:
-    xtorch_ops.dequant_int4(
+    kunlun_ops.dequant_int4(
         x=x,
         scale=scale,
         zero=zero,
@@ -2350,7 +2489,7 @@ def fast_topkv2(
     score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
 ) -> torch.Tensor:
     assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
@@ -2359,7 +2498,7 @@ def fast_topkv2_cuda(
|
||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||
) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
topk_indices = kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
return topk_indices
|
||||
|
||||
|
||||
@@ -2798,7 +2937,7 @@ def lora_matmul_inplace(
|
||||
alpha: float = 1.0,
|
||||
beta: float = 1.0,
|
||||
) -> None:
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=output_tensor,
|
||||
@@ -2819,7 +2958,7 @@ def lora_matmul_inplace_cuda(
|
||||
alpha: float = 1.0,
|
||||
beta: float = 1.0,
|
||||
) -> None:
|
||||
xtorch_ops.matmul(
|
||||
kunlun_ops.matmul(
|
||||
x=x.contiguous(),
|
||||
w=w.contiguous(),
|
||||
out=output_tensor,
|
||||
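The changes above are a mechanical namespace rename (`xtorch_ops` to `kunlun_ops`) across every op wrapper. A sketch of how such a rename can be made non-breaking for downstream callers, assuming the new namespace is importable as `kunlun_ops` (the proxy class and warning wording here are illustrative, not part of this diff):

```python
# Hypothetical compatibility shim for an xtorch_ops -> kunlun_ops rename:
# old call sites keep working but emit a DeprecationWarning on each access.
# Names other than `kunlun_ops` / `xtorch_ops` are assumptions for this sketch.
import warnings


class _RenamedNamespace:
    """Proxy that forwards attribute access to the renamed ops namespace."""

    def __init__(self, new_ns, old_name, new_name):
        self._new_ns = new_ns
        self._old_name = old_name
        self._new_name = new_name

    def __getattr__(self, attr):
        warnings.warn(
            f"{self._old_name}.{attr} is deprecated; "
            f"use {self._new_name}.{attr} instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(self._new_ns, attr)


# Stand-in for the real kunlun_ops namespace, so the shim is demonstrable
# without the Kunlun runtime installed.
class _FakeKunlunOps:
    @staticmethod
    def matmul(x, w):
        # Plain nested-list matrix multiply, purely for demonstration.
        return [
            [sum(a * b for a, b in zip(row, col)) for col in zip(*w)]
            for row in x
        ]


# Legacy name resolves to the new implementation, with a warning.
xtorch_ops = _RenamedNamespace(_FakeKunlunOps, "xtorch_ops", "kunlun_ops")
```

A shim like this lets the rename land in one commit while out-of-tree callers migrate at their own pace; once all call sites use the new name, the proxy is deleted.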