feat: support flashinfer mla attention for deepseek v3 (#3550)
This commit is contained in:
16
.github/workflows/pr-test.yml
vendored
16
.github/workflows/pr-test.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -98,7 +98,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -163,7 +163,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -209,7 +209,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -243,7 +243,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
@@ -283,7 +283,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
git clone https://github.com/merrymercy/human-eval.git
|
git clone https://github.com/merrymercy/human-eval.git
|
||||||
@@ -308,7 +308,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
env:
|
env:
|
||||||
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
|
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
|
||||||
run: |
|
run: |
|
||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
git clone https://github.com/merrymercy/human-eval.git
|
git clone https://github.com/merrymercy/human-eval.git
|
||||||
|
|||||||
@@ -21,12 +21,13 @@ runtime_common = [
|
|||||||
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
|
"hf_transfer", "huggingface_hub", "interegular", "modelscope",
|
||||||
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
|
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
|
||||||
"psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
|
"psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
|
||||||
"torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10"
|
"torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10", "ninja"
|
||||||
]
|
]
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]", "cuda-python",
|
"sglang[runtime_common]", "cuda-python",
|
||||||
"sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
|
"sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
|
||||||
"flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
|
"flashinfer_python>=0.2.1.post1",
|
||||||
|
"outlines>=0.0.44,<=0.1.11",
|
||||||
]
|
]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
|
|||||||
@@ -38,5 +38,7 @@ class GlobalConfig:
|
|||||||
self.enable_precache_with_tracing = True
|
self.enable_precache_with_tracing = True
|
||||||
self.enable_parallel_encoding = True
|
self.enable_parallel_encoding = True
|
||||||
|
|
||||||
|
self.enable_flashinfer_mla = False
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if server_args.attention_backend == "flashinfer":
|
if server_args.attention_backend == "flashinfer":
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer_python",
|
"flashinfer_python",
|
||||||
"0.2.0.post2",
|
"0.2.1.post1",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
|
|||||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
@@ -20,6 +21,7 @@ import triton.language as tl
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention import AttentionBackend
|
from sglang.srt.layers.attention import AttentionBackend
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
@@ -35,7 +37,7 @@ if is_flashinfer_available():
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
from flashinfer.decode import PosEncodingMode
|
from flashinfer.mla import BatchMLAPagedAttentionWrapper
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DecodeMetadata:
|
class DecodeMetadata:
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
|
decode_wrappers: List[
|
||||||
|
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
|
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
|
||||||
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
global_config.flashinfer_workspace_size = 512 * 1024 * 1024
|
||||||
|
|
||||||
|
self.enable_flashinfer_mla = False
|
||||||
|
if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
|
||||||
|
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||||
|
self.enable_flashinfer_mla = True
|
||||||
|
global_config.enable_flashinfer_mla = True
|
||||||
|
|
||||||
# Allocate buffers
|
# Allocate buffers
|
||||||
global global_workspace_buffer
|
global global_workspace_buffer
|
||||||
if global_workspace_buffer is None:
|
if global_workspace_buffer is None:
|
||||||
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
for _ in range(self.num_wrappers)
|
for _ in range(self.num_wrappers)
|
||||||
]
|
]
|
||||||
|
if self.enable_flashinfer_mla:
|
||||||
|
self.qo_indptr = [
|
||||||
|
torch.zeros(
|
||||||
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
for _ in range(self.num_wrappers)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
assert self.num_wrappers == 1
|
assert self.num_wrappers == 1
|
||||||
self.kv_indptr = [kv_indptr_buf]
|
self.kv_indptr = [kv_indptr_buf]
|
||||||
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.prefill_wrappers_verify.append(
|
self.prefill_wrappers_verify.append(
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
)
|
)
|
||||||
self.decode_wrappers.append(
|
if self.enable_flashinfer_mla:
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
self.decode_wrappers.append(
|
||||||
self.workspace_buffer,
|
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
|
||||||
"NHD",
|
)
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
else:
|
||||||
|
self.decode_wrappers.append(
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
"NHD",
|
||||||
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
# Create indices updater
|
# Create indices updater
|
||||||
if not skip_prefill:
|
if not skip_prefill:
|
||||||
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
for i in range(self.num_wrappers):
|
for i in range(self.num_wrappers):
|
||||||
decode_wrappers.append(
|
if self.enable_flashinfer_mla:
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
decode_wrappers.append(
|
||||||
self.workspace_buffer,
|
BatchMLAPagedAttentionWrapper(
|
||||||
"NHD",
|
self.workspace_buffer,
|
||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
qo_indptr=self.qo_indptr[i][: num_tokens + 1],
|
||||||
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
|
kv_indptr=self.kv_indptr[i][: num_tokens + 1],
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
kv_indices=self.cuda_graph_kv_indices[i],
|
||||||
paged_kv_last_page_len_buffer=self.kv_last_page_len[
|
kv_len_arr=self.kv_last_page_len[:num_tokens],
|
||||||
:num_tokens
|
backend="fa2",
|
||||||
],
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decode_wrappers.append(
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
"NHD",
|
||||||
|
use_cuda_graph=True,
|
||||||
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
|
paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
|
||||||
|
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||||
|
paged_kv_last_page_len_buffer=self.kv_last_page_len[
|
||||||
|
:num_tokens
|
||||||
|
],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
seq_lens_sum = seq_lens.sum().item()
|
seq_lens_sum = seq_lens.sum().item()
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
):
|
):
|
||||||
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
if global_config.enable_flashinfer_mla:
|
||||||
self._get_wrapper_idx(layer)
|
cache_loc = (
|
||||||
]
|
forward_batch.out_cache_loc
|
||||||
cache_loc = (
|
if not layer.is_cross_attention
|
||||||
forward_batch.out_cache_loc
|
else forward_batch.encoder_out_cache_loc
|
||||||
if not layer.is_cross_attention
|
|
||||||
else forward_batch.encoder_out_cache_loc
|
|
||||||
)
|
|
||||||
|
|
||||||
logits_soft_cap = layer.logit_cap
|
|
||||||
|
|
||||||
if not self.forward_metadata.use_ragged:
|
|
||||||
if k is not None:
|
|
||||||
assert v is not None
|
|
||||||
if save_kv_cache:
|
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
|
||||||
)
|
|
||||||
|
|
||||||
o = prefill_wrapper_paged.forward(
|
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
|
||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
|
||||||
causal=not layer.is_cross_attention,
|
|
||||||
sm_scale=layer.scaling,
|
|
||||||
window_left=layer.sliding_window_size,
|
|
||||||
logits_soft_cap=logits_soft_cap,
|
|
||||||
k_scale=layer.k_scale,
|
|
||||||
v_scale=layer.v_scale,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
logits_soft_cap = layer.logit_cap
|
||||||
|
|
||||||
|
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||||
v.view(-1, layer.tp_v_head_num, layer.head_dim),
|
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
|
||||||
causal=True,
|
causal=True,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.forward_metadata.extend_no_prefix:
|
o = o1
|
||||||
o = o1
|
|
||||||
else:
|
|
||||||
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
|
||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
|
||||||
causal=False,
|
|
||||||
sm_scale=layer.scaling,
|
|
||||||
logits_soft_cap=layer.logit_cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
o, _ = merge_state(o1, s1, o2, s2)
|
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
else:
|
||||||
|
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
|
||||||
|
self._get_wrapper_idx(layer)
|
||||||
|
]
|
||||||
|
cache_loc = (
|
||||||
|
forward_batch.out_cache_loc
|
||||||
|
if not layer.is_cross_attention
|
||||||
|
else forward_batch.encoder_out_cache_loc
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_soft_cap = layer.logit_cap
|
||||||
|
|
||||||
|
if not self.forward_metadata.use_ragged:
|
||||||
|
if k is not None:
|
||||||
|
assert v is not None
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
o = prefill_wrapper_paged.forward(
|
||||||
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
|
causal=not layer.is_cross_attention,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
window_left=layer.sliding_window_size,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
k_scale=layer.k_scale,
|
||||||
|
v_scale=layer.v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
|
||||||
|
q.view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k.view(-1, layer.tp_k_head_num, layer.head_dim),
|
||||||
|
v.view(-1, layer.tp_v_head_num, layer.head_dim),
|
||||||
|
causal=True,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=logits_soft_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.forward_metadata.extend_no_prefix:
|
||||||
|
o = o1
|
||||||
|
else:
|
||||||
|
o2, s2 = prefill_wrapper_paged.forward_return_lse(
|
||||||
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
|
causal=False,
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=layer.logit_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
o, _ = merge_state(o1, s1, o2, s2)
|
||||||
|
|
||||||
|
if save_kv_cache:
|
||||||
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def forward_decode(
|
def forward_decode(
|
||||||
self,
|
self,
|
||||||
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
else forward_batch.encoder_out_cache_loc
|
else forward_batch.encoder_out_cache_loc
|
||||||
)
|
)
|
||||||
|
|
||||||
if k is not None:
|
if self.enable_flashinfer_mla:
|
||||||
assert v is not None
|
if k is not None:
|
||||||
if save_kv_cache:
|
assert v is not None
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
if save_kv_cache:
|
||||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
)
|
layer,
|
||||||
|
cache_loc,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
|
||||||
|
o = decode_wrapper.run(
|
||||||
|
reshaped_q[:, :, : layer.v_head_dim],
|
||||||
|
reshaped_q[:, :, layer.v_head_dim :],
|
||||||
|
reshaped_k[:, :, : layer.v_head_dim],
|
||||||
|
reshaped_k[:, :, layer.v_head_dim :],
|
||||||
|
)
|
||||||
|
|
||||||
o = decode_wrapper.forward(
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
else:
|
||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
if k is not None:
|
||||||
sm_scale=layer.scaling,
|
assert v is not None
|
||||||
logits_soft_cap=layer.logit_cap,
|
if save_kv_cache:
|
||||||
k_scale=layer.k_scale,
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
v_scale=layer.v_scale,
|
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||||
)
|
)
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
o = decode_wrapper.forward(
|
||||||
|
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
|
sm_scale=layer.scaling,
|
||||||
|
logits_soft_cap=layer.logit_cap,
|
||||||
|
k_scale=layer.k_scale,
|
||||||
|
v_scale=layer.v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def _get_wrapper_idx(self, layer: RadixAttention):
|
def _get_wrapper_idx(self, layer: RadixAttention):
|
||||||
if self.num_wrappers == 1:
|
if self.num_wrappers == 1:
|
||||||
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[
|
||||||
|
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[
|
||||||
|
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
|
||||||
|
],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
self,
|
self,
|
||||||
wrapper: BatchDecodeWithPagedKVCacheWrapper,
|
wrapper: Union[
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
|
||||||
|
],
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
wrapper.begin_forward(
|
if global_config.enable_flashinfer_mla:
|
||||||
kv_indptr,
|
sm_scale = 1.0 / math.sqrt(192)
|
||||||
kv_indices,
|
q_indptr = torch.arange(0, bs + 1).to(0).int()
|
||||||
self.kv_last_page_len[:bs],
|
kv_lens = paged_kernel_lens.to(torch.int32)
|
||||||
self.num_qo_heads,
|
wrapper.plan(
|
||||||
self.num_kv_heads,
|
q_indptr,
|
||||||
self.head_dim,
|
kv_indptr,
|
||||||
1,
|
kv_indices,
|
||||||
data_type=self.data_type,
|
kv_lens,
|
||||||
q_data_type=self.q_data_type,
|
self.num_qo_heads,
|
||||||
non_blocking=True,
|
512,
|
||||||
)
|
64,
|
||||||
|
1,
|
||||||
|
False,
|
||||||
|
sm_scale,
|
||||||
|
self.data_type,
|
||||||
|
self.data_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
1,
|
||||||
|
data_type=self.data_type,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
non_blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashInferIndicesUpdaterPrefill:
|
class FlashInferIndicesUpdaterPrefill:
|
||||||
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
|
|
||||||
# extend part
|
# extend part
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
wrapper_ragged.begin_forward(
|
if global_config.enable_flashinfer_mla:
|
||||||
qo_indptr,
|
wrapper_ragged.begin_forward(
|
||||||
|
qo_indptr=qo_indptr,
|
||||||
|
kv_indptr=qo_indptr,
|
||||||
|
num_qo_heads=self.num_qo_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
head_dim_qk=192,
|
||||||
|
head_dim_vo=128,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper_ragged.begin_forward(
|
||||||
|
qo_indptr,
|
||||||
|
qo_indptr,
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not global_config.enable_flashinfer_mla:
|
||||||
|
# cached part
|
||||||
|
wrapper_paged.begin_forward(
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
self.num_qo_heads,
|
self.num_qo_heads,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
|
1,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
|
custom_mask=custom_mask,
|
||||||
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# cached part
|
|
||||||
wrapper_paged.begin_forward(
|
|
||||||
qo_indptr,
|
|
||||||
kv_indptr,
|
|
||||||
kv_indices,
|
|
||||||
self.kv_last_page_len[:bs],
|
|
||||||
self.num_qo_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
1,
|
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
custom_mask=custom_mask,
|
|
||||||
non_blocking=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMultiStepDraftBackend:
|
class FlashInferMultiStepDraftBackend:
|
||||||
"""
|
"""
|
||||||
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
|
|||||||
window_left,
|
window_left,
|
||||||
logits_soft_cap,
|
logits_soft_cap,
|
||||||
head_dim,
|
head_dim,
|
||||||
|
head_dim,
|
||||||
empty_q_data,
|
empty_q_data,
|
||||||
empty_kv_cache,
|
empty_kv_cache,
|
||||||
stream.cuda_stream,
|
stream.cuda_stream,
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ global_server_args_dict = {
|
|||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
"device": ServerArgs.device,
|
"device": ServerArgs.device,
|
||||||
|
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
|
|||||||
monkey_patch_p2p_access_check,
|
monkey_patch_p2p_access_check,
|
||||||
monkey_patch_vllm_gguf_config,
|
monkey_patch_vllm_gguf_config,
|
||||||
set_cpu_offload_max_bytes,
|
set_cpu_offload_max_bytes,
|
||||||
|
set_cuda_arch,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -110,8 +111,14 @@ class ModelRunner:
|
|||||||
):
|
):
|
||||||
# TODO: add MLA optimization on CPU
|
# TODO: add MLA optimization on CPU
|
||||||
if self.server_args.device != "cpu":
|
if self.server_args.device != "cpu":
|
||||||
logger.info("MLA optimization is turned on. Use triton backend.")
|
if server_args.enable_flashinfer_mla:
|
||||||
self.server_args.attention_backend = "triton"
|
logger.info(
|
||||||
|
"FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
|
||||||
|
)
|
||||||
|
self.server_args.attention_backend = "flashinfer"
|
||||||
|
else:
|
||||||
|
logger.info("MLA optimization is turned on. Use triton backend.")
|
||||||
|
self.server_args.attention_backend = "triton"
|
||||||
|
|
||||||
if self.server_args.enable_double_sparsity:
|
if self.server_args.enable_double_sparsity:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -169,6 +176,7 @@ class ModelRunner:
|
|||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"device": server_args.device,
|
"device": server_args.device,
|
||||||
|
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -292,6 +300,8 @@ class ModelRunner:
|
|||||||
if torch.cuda.get_device_capability()[1] < 5:
|
if torch.cuda.get_device_capability()[1] < 5:
|
||||||
raise RuntimeError("SGLang only supports sm75 and above.")
|
raise RuntimeError("SGLang only supports sm75 and above.")
|
||||||
|
|
||||||
|
set_cuda_arch()
|
||||||
|
|
||||||
# Prepare the model config
|
# Prepare the model config
|
||||||
self.load_config = LoadConfig(
|
self.load_config = LoadConfig(
|
||||||
load_format=self.server_args.load_format,
|
load_format=self.server_args.load_format,
|
||||||
|
|||||||
@@ -510,14 +510,20 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Use normal computation for prefill and use weight absorption for extend/decode
|
if global_server_args_dict["enable_flashinfer_mla"]:
|
||||||
if (
|
if forward_batch.forward_mode.is_extend():
|
||||||
forward_batch.forward_mode.is_extend()
|
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||||
and forward_batch.extend_prefix_lens.sum() == 0
|
else:
|
||||||
):
|
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
|
||||||
else:
|
else:
|
||||||
return self.forward_absorb(positions, hidden_states, forward_batch)
|
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||||
|
if (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and forward_batch.extend_prefix_lens.sum() == 0
|
||||||
|
):
|
||||||
|
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||||
|
else:
|
||||||
|
return self.forward_absorb(positions, hidden_states, forward_batch)
|
||||||
|
|
||||||
def forward_normal(
|
def forward_normal(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -168,6 +168,8 @@ class ServerArgs:
|
|||||||
tool_call_parser: str = None
|
tool_call_parser: str = None
|
||||||
enable_hierarchical_cache: bool = False
|
enable_hierarchical_cache: bool = False
|
||||||
|
|
||||||
|
enable_flashinfer_mla: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set missing default values
|
# Set missing default values
|
||||||
if self.tokenizer_path is None:
|
if self.tokenizer_path is None:
|
||||||
@@ -693,6 +695,11 @@ class ServerArgs:
|
|||||||
default=ServerArgs.grammar_backend,
|
default=ServerArgs.grammar_backend,
|
||||||
help="Choose the backend for grammar-guided decoding.",
|
help="Choose the backend for grammar-guided decoding.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-flashinfer-mla",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable FlashInfer MLA optimization",
|
||||||
|
)
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -1444,3 +1444,10 @@ def launch_dummy_health_check_server(host, port):
|
|||||||
timeout_keep_alive=5,
|
timeout_keep_alive=5,
|
||||||
loop="uvloop",
|
loop="uvloop",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_cuda_arch():
|
||||||
|
if is_flashinfer_available():
|
||||||
|
capability = torch.cuda.get_device_capability()
|
||||||
|
arch = f"{capability[0]}.{capability[1]}"
|
||||||
|
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
|
||||||
|
|||||||
@@ -4,17 +4,19 @@ set -euxo pipefail
|
|||||||
# Install the dependency in CI.
|
# Install the dependency in CI.
|
||||||
|
|
||||||
# Use repo from environment variable, passed from GitHub Actions
|
# Use repo from environment variable, passed from GitHub Actions
|
||||||
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}"
|
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
|
||||||
|
|
||||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
bash "${SCRIPT_DIR}/killall_sglang.sh"
|
bash "${SCRIPT_DIR}/killall_sglang.sh"
|
||||||
|
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip uninstall flashinfer -y
|
pip uninstall flashinfer -y
|
||||||
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
|
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
|
||||||
|
|
||||||
|
rm -rf /root/.cache/flashinfer
|
||||||
# Force reinstall flashinfer and torch_memory_saver
|
# Force reinstall flashinfer and torch_memory_saver
|
||||||
pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
|
pip install flashinfer_python==0.2.1.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
|
||||||
|
|
||||||
pip install torch_memory_saver --force-reinstall
|
pip install torch_memory_saver --force-reinstall
|
||||||
|
|
||||||
pip install transformers==4.45.2 sentence_transformers accelerate peft
|
pip install transformers==4.45.2 sentence_transformers accelerate peft
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
"speculative_eagle_topk": 8,
|
"speculative_eagle_topk": 8,
|
||||||
"speculative_num_draft_tokens": 64,
|
"speculative_num_draft_tokens": 64,
|
||||||
"mem_fraction_static": 0.7,
|
"mem_fraction_static": 0.7,
|
||||||
|
"cuda_graph_max_bs": 32,
|
||||||
}
|
}
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@@ -124,6 +125,8 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
"64",
|
"64",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.7",
|
"0.7",
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user