[Doc] Steps to add a new attention backend (#8155)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
part: [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
part: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -69,7 +69,7 @@ jobs:
|
|||||||
timeout-minutes: 30
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 9
|
python3 run_suite.py --suite per-commit --auto-partition-id ${{ matrix.part }} --auto-partition-size 10
|
||||||
|
|
||||||
unit-test-backend-2-gpu:
|
unit-test-backend-2-gpu:
|
||||||
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
|
||||||
|
|||||||
@@ -52,3 +52,31 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
|
|||||||
```bash
|
```bash
|
||||||
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Steps to add a new attention backend
|
||||||
|
To add a new attention backend, you can learn from the existing backends
|
||||||
|
(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)
|
||||||
|
and follow the steps below.
|
||||||
|
|
||||||
|
1. Run without cuda graph. Support the two forward functions
|
||||||
|
- forward_extend
|
||||||
|
- Will be used for prefill, prefill with KV cache, and target verification
|
||||||
|
- It will be called once per layer
|
||||||
|
- forward_decode
|
||||||
|
- Will be used for normal decode, and draft decode
|
||||||
|
- It will be called once per layer
|
||||||
|
- init_forward_metadata
|
||||||
|
- Initialize the class and common metadata shared by all layers
|
||||||
|
- Call the plan function for optimizations like split_kv
|
||||||
|
- It will be called once per forward
|
||||||
|
2. Run with cuda graph. It has two phases (capture and replay) and you need to implement three functions
|
||||||
|
- init_cuda_graph_state
|
||||||
|
- It will be called once during life time
|
||||||
|
- Create all common shared buffers
|
||||||
|
- init_forward_metadata_capture_cuda_graph
|
||||||
|
- It will be called before capturing a cuda graph
|
||||||
|
- It is similar to init_forward_metadata but write the medatada to some pre-defined buffers
|
||||||
|
- init_forward_metadata_replay_cuda_graph
|
||||||
|
- It will be called before replaying a cuda graph
|
||||||
|
- This function is in the critical path and needs to be fast
|
||||||
|
|||||||
@@ -13,14 +13,14 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""
|
"""
|
||||||
The definition of objects transferred between different
|
The definition of objects transferred between different
|
||||||
processes (TokenizerManager, DetokenizerManager, Controller).
|
processes (TokenizerManager, DetokenizerManager, Scheduler).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||||
@@ -545,7 +545,7 @@ class EmbeddingReqInput:
|
|||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Dummy sampling params for compatibility
|
# Dummy sampling params for compatibility
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
||||||
# Dummy input embeds for compatibility
|
# Dummy input embeds for compatibility
|
||||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||||
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
||||||
@@ -953,17 +953,6 @@ class ProfileReqType(Enum):
|
|||||||
STOP_PROFILE = 2
|
STOP_PROFILE = 2
|
||||||
|
|
||||||
|
|
||||||
class ExpertDistributionReq(Enum):
|
|
||||||
START_RECORD = 1
|
|
||||||
STOP_RECORD = 2
|
|
||||||
DUMP_RECORD = 3
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExpertDistributionReqOutput:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProfileReq:
|
class ProfileReq:
|
||||||
type: ProfileReqType
|
type: ProfileReqType
|
||||||
@@ -1013,6 +1002,17 @@ class HealthCheckOutput:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ExpertDistributionReq(Enum):
|
||||||
|
START_RECORD = 1
|
||||||
|
STOP_RECORD = 2
|
||||||
|
DUMP_RECORD = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpertDistributionReqOutput:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Function:
|
class Function:
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
|
|||||||
@@ -155,11 +155,11 @@ suites = {
|
|||||||
"per-commit-2-gpu": [
|
"per-commit-2-gpu": [
|
||||||
TestFile("models/lora/test_lora_tp.py", 116),
|
TestFile("models/lora/test_lora_tp.py", 116),
|
||||||
TestFile("test_data_parallelism.py", 73),
|
TestFile("test_data_parallelism.py", 73),
|
||||||
TestFile("test_dp_attention.py", 137),
|
TestFile("test_dp_attention.py", 277),
|
||||||
TestFile("test_mla_tp.py", 170),
|
TestFile("test_mla_tp.py", 170),
|
||||||
TestFile("test_patch_torch.py", 19),
|
TestFile("test_patch_torch.py", 19),
|
||||||
TestFile("test_update_weights_from_distributed.py", 103),
|
TestFile("test_update_weights_from_distributed.py", 103),
|
||||||
TestFile("test_release_memory_occupation.py", 44),
|
TestFile("test_release_memory_occupation.py", 127),
|
||||||
],
|
],
|
||||||
"per-commit-2-gpu-amd": [
|
"per-commit-2-gpu-amd": [
|
||||||
TestFile("models/lora/test_lora_tp.py", 116),
|
TestFile("models/lora/test_lora_tp.py", 116),
|
||||||
@@ -170,7 +170,7 @@ suites = {
|
|||||||
],
|
],
|
||||||
"per-commit-4-gpu": [
|
"per-commit-4-gpu": [
|
||||||
TestFile("test_local_attn.py", 250),
|
TestFile("test_local_attn.py", 250),
|
||||||
TestFile("test_pp_single_node.py", 150),
|
TestFile("test_pp_single_node.py", 372),
|
||||||
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
TestFile("test_multi_instance_release_memory_occupation.py", 64),
|
||||||
],
|
],
|
||||||
"per-commit-4-gpu-deepep": [
|
"per-commit-4-gpu-deepep": [
|
||||||
@@ -182,12 +182,12 @@ suites = {
|
|||||||
"per-commit-8-gpu": [
|
"per-commit-8-gpu": [
|
||||||
# Disabled because it hangs on the CI.
|
# Disabled because it hangs on the CI.
|
||||||
# TestFile("test_moe_ep.py", 181),
|
# TestFile("test_moe_ep.py", 181),
|
||||||
TestFile("test_disaggregation.py", 270),
|
TestFile("test_disaggregation.py", 499),
|
||||||
TestFile("test_disaggregation_different_tp.py", 155),
|
TestFile("test_disaggregation_different_tp.py", 155),
|
||||||
TestFile("test_full_deepseek_v3.py", 463),
|
TestFile("test_full_deepseek_v3.py", 333),
|
||||||
],
|
],
|
||||||
"per-commit-8-gpu-deepep": [
|
"per-commit-8-gpu-deepep": [
|
||||||
TestFile("test_deepep_large.py", 485),
|
TestFile("test_deepep_large.py", 338),
|
||||||
],
|
],
|
||||||
"per-commit-8-gpu-amd": [
|
"per-commit-8-gpu-amd": [
|
||||||
TestFile("test_full_deepseek_v3.py", 250),
|
TestFile("test_full_deepseek_v3.py", 250),
|
||||||
@@ -214,11 +214,11 @@ suites = {
|
|||||||
TestFile("test_nightly_gsm8k_eval_amd.py"),
|
TestFile("test_nightly_gsm8k_eval_amd.py"),
|
||||||
],
|
],
|
||||||
"vllm_dependency_test": [
|
"vllm_dependency_test": [
|
||||||
TestFile("test_awq.py"),
|
TestFile("test_awq.py", 163),
|
||||||
TestFile("test_bnb.py"),
|
TestFile("test_bnb.py", 5),
|
||||||
TestFile("test_gguf.py", 78),
|
TestFile("test_gguf.py", 96),
|
||||||
TestFile("test_gptqmodel_dynamic.py", 72),
|
TestFile("test_gptqmodel_dynamic.py", 102),
|
||||||
TestFile("test_vllm_dependency.py"),
|
TestFile("test_vllm_dependency.py", 185),
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user