[Doc] Steps to add a new attention backend (#8155)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
timeout-minutes: 30
|
||||
run: |
|
||||
cd test/srt
|
||||
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9
|
||||
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 10
|
||||
|
||||
unit-test-backend-2-gpu:
|
||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||
|
||||
@@ -52,3 +52,31 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
||||
```
|
||||
|
||||
|
||||
## Steps to add a new attention backend
|
||||
To add a new attention backend, you can learn from the existing backends
|
||||
(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)
|
||||
and follow the steps below.
|
||||
|
||||
1. Run without cuda graph. Support the two forward functions
|
||||
- forward_extend
|
||||
- Will be used for prefill, prefill with KV cache, and target verification
|
||||
- It will be called once per layer
|
||||
- forward_decode
|
||||
- Will be used for normal decode, and draft decode
|
||||
- It will be called once per layer
|
||||
- init_forward_metadata
|
||||
- Initialize the class and common metadata shared by all layers
|
||||
- Call the plan function for optimizations like split_kv
|
||||
- It will be called once per forward
|
||||
2. Run with cuda graph. It has two phases (capture and replay) and you need to implement three functions
|
||||
- init_cuda_graph_state
|
||||
- It will be called once during its lifetime
|
||||
- Create all common shared buffers
|
||||
- init_forward_metadata_capture_cuda_graph
|
||||
- It will be called before capturing a cuda graph
|
||||
- It is similar to init_forward_metadata but writes the metadata to some pre-defined buffers
|
||||
- init_forward_metadata_replay_cuda_graph
|
||||
- It will be called before replaying a cuda graph
|
||||
- This function is in the critical path and needs to be fast
|
||||
|
||||
@@ -13,14 +13,14 @@
|
||||
# ==============================================================================
|
||||
"""
|
||||
The definition of objects transferred between different
|
||||
processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
processes (TokenizerManager, DetokenizerManager, Scheduler).
|
||||
"""
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||
@@ -545,7 +545,7 @@ class EmbeddingReqInput:
|
||||
# The request id.
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: Union[List[Dict], Dict] = None
|
||||
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||
# Dummy input embeds for compatibility
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||
@@ -953,17 +953,6 @@ class ProfileReqType(Enum):
|
||||
STOP_PROFILE = 2
|
||||
|
||||
|
||||
class ExpertDistributionReq(Enum):
|
||||
START_RECORD = 1
|
||||
STOP_RECORD = 2
|
||||
DUMP_RECORD = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertDistributionReqOutput:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileReq:
|
||||
type: ProfileReqType
|
||||
@@ -1013,6 +1002,17 @@ class HealthCheckOutput:
|
||||
pass
|
||||
|
||||
|
||||
class ExpertDistributionReq(Enum):
|
||||
START_RECORD = 1
|
||||
STOP_RECORD = 2
|
||||
DUMP_RECORD = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpertDistributionReqOutput:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
description: Optional[str] = None
|
||||
|
||||
@@ -155,11 +155,11 @@ suites = {
|
||||
"per-commit-2-gpu": [
|
||||
TestFile("models/lora/test_lora_tp.py", 116),
|
||||
TestFile("test_data_parallelism.py", 73),
|
||||
TestFile("test_dp_attention.py", 137),
|
||||
TestFile("test_dp_attention.py", 277),
|
||||
TestFile("test_mla_tp.py", 170),
|
||||
TestFile("test_patch_torch.py", 19),
|
||||
TestFile("test_update_weights_from_distributed.py", 103),
|
||||
TestFile("test_release_memory_occupation.py", 44),
|
||||
TestFile("test_release_memory_occupation.py", 127),
|
||||
],
|
||||
"per-commit-2-gpu-amd": [
|
||||
TestFile("models/lora/test_lora_tp.py", 116),
|
||||
@@ -170,7 +170,7 @@ suites = {
|
||||
],
|
||||
"per-commit-4-gpu": [
|
||||
TestFile("test_local_attn.py", 250),
|
||||
TestFile("test_pp_single_node.py", 150),
|
||||
TestFile("test_pp_single_node.py", 372),
|
||||
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
||||
],
|
||||
"per-commit-4-gpu-deepep": [
|
||||
@@ -182,12 +182,12 @@ suites = {
|
||||
"per-commit-8-gpu": [
|
||||
# Disabled because it hangs on the CI.
|
||||
# TestFile("test_moe_ep.py", 181),
|
||||
TestFile("test_disaggregation.py", 270),
|
||||
TestFile("test_disaggregation.py", 499),
|
||||
TestFile("test_disaggregation_different_tp.py", 155),
|
||||
TestFile("test_full_deepseek_v3.py", 463),
|
||||
TestFile("test_full_deepseek_v3.py", 333),
|
||||
],
|
||||
"per-commit-8-gpu-deepep": [
|
||||
TestFile("test_deepep_large.py", 485),
|
||||
TestFile("test_deepep_large.py", 338),
|
||||
],
|
||||
"per-commit-8-gpu-amd": [
|
||||
TestFile("test_full_deepseek_v3.py", 250),
|
||||
@@ -214,11 +214,11 @@ suites = {
|
||||
TestFile("test_nightly_gsm8k_eval_amd.py"),
|
||||
],
|
||||
"vllm_dependency_test": [
|
||||
TestFile("test_awq.py"),
|
||||
TestFile("test_bnb.py"),
|
||||
TestFile("test_gguf.py", 78),
|
||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
||||
TestFile("test_vllm_dependency.py"),
|
||||
TestFile("test_awq.py", 163),
|
||||
TestFile("test_bnb.py", 5),
|
||||
TestFile("test_gguf.py", 96),
|
||||
TestFile("test_gptqmodel_dynamic.py", 102),
|
||||
TestFile("test_vllm_dependency.py", 185),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user