feat: support flashinfer mla attention for deepseek v3 (#3550)

This commit is contained in:
Yineng Zhang
2025-02-14 08:50:14 +08:00
committed by GitHub
parent 368de3661e
commit 70f894b810
12 changed files with 299 additions and 135 deletions

View File

@@ -72,7 +72,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -98,7 +98,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -123,7 +123,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -163,7 +163,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -209,7 +209,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -243,7 +243,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
@@ -283,7 +283,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
git clone https://github.com/merrymercy/human-eval.git git clone https://github.com/merrymercy/human-eval.git
@@ -308,7 +308,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
env: env:
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
run: | run: |
bash scripts/ci_install_dependency.sh bash scripts/ci_install_dependency.sh
git clone https://github.com/merrymercy/human-eval.git git clone https://github.com/merrymercy/human-eval.git

View File

@@ -21,12 +21,13 @@ runtime_common = [
"hf_transfer", "huggingface_hub", "interegular", "modelscope", "hf_transfer", "huggingface_hub", "interegular", "modelscope",
"orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
"torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10" "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10", "ninja"
] ]
srt = [ srt = [
"sglang[runtime_common]", "cuda-python", "sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2", "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
"flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11" "flashinfer_python>=0.2.1.post1",
"outlines>=0.0.44,<=0.1.11",
] ]
# HIP (Heterogeneous-computing Interface for Portability) for AMD # HIP (Heterogeneous-computing Interface for Portability) for AMD

View File

@@ -38,5 +38,7 @@ class GlobalConfig:
self.enable_precache_with_tracing = True self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True self.enable_parallel_encoding = True
self.enable_flashinfer_mla = False
global_config = GlobalConfig() global_config = GlobalConfig()

View File

@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if server_args.attention_backend == "flashinfer": if server_args.attention_backend == "flashinfer":
assert_pkg_version( assert_pkg_version(
"flashinfer_python", "flashinfer_python",
"0.2.0.post2", "0.2.1.post1",
"Please uninstall the old version and " "Please uninstall the old version and "
"reinstall the latest version by following the instructions " "reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",

View File

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
""" """
import math
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
from sglang.global_config import global_config from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available from sglang.srt.utils import is_flashinfer_available
@@ -35,7 +37,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
) )
from flashinfer.cascade import merge_state from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode from flashinfer.mla import BatchMLAPagedAttentionWrapper
class WrapperDispatch(Enum): class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
@dataclass @dataclass
class DecodeMetadata: class DecodeMetadata:
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] decode_wrappers: List[
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
]
@dataclass @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
global_config.flashinfer_workspace_size = 512 * 1024 * 1024 global_config.flashinfer_workspace_size = 512 * 1024 * 1024
self.enable_flashinfer_mla = False
if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
if global_server_args_dict["enable_flashinfer_mla"]:
self.enable_flashinfer_mla = True
global_config.enable_flashinfer_mla = True
# Allocate buffers # Allocate buffers
global global_workspace_buffer global global_workspace_buffer
if global_workspace_buffer is None: if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
) )
for _ in range(self.num_wrappers) for _ in range(self.num_wrappers)
] ]
if self.enable_flashinfer_mla:
self.qo_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]
else: else:
assert self.num_wrappers == 1 assert self.num_wrappers == 1
self.kv_indptr = [kv_indptr_buf] self.kv_indptr = [kv_indptr_buf]
@@ -153,6 +170,11 @@ class FlashInferAttnBackend(AttentionBackend):
self.prefill_wrappers_verify.append( self.prefill_wrappers_verify.append(
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
) )
if self.enable_flashinfer_mla:
self.decode_wrappers.append(
BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
)
else:
self.decode_wrappers.append( self.decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper( BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, self.workspace_buffer,
@@ -274,6 +296,19 @@ class FlashInferAttnBackend(AttentionBackend):
if forward_mode.is_decode_or_idle(): if forward_mode.is_decode_or_idle():
decode_wrappers = [] decode_wrappers = []
for i in range(self.num_wrappers): for i in range(self.num_wrappers):
if self.enable_flashinfer_mla:
decode_wrappers.append(
BatchMLAPagedAttentionWrapper(
self.workspace_buffer,
use_cuda_graph=True,
qo_indptr=self.qo_indptr[i][: num_tokens + 1],
kv_indptr=self.kv_indptr[i][: num_tokens + 1],
kv_indices=self.cuda_graph_kv_indices[i],
kv_len_arr=self.kv_last_page_len[:num_tokens],
backend="fa2",
)
)
else:
decode_wrappers.append( decode_wrappers.append(
BatchDecodeWithPagedKVCacheWrapper( BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, self.workspace_buffer,
@@ -375,6 +410,36 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
save_kv_cache=True, save_kv_cache=True,
): ):
if global_config.enable_flashinfer_mla:
cache_loc = (
forward_batch.out_cache_loc
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
logits_soft_cap = layer.logit_cap
o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
k.view(-1, layer.tp_k_head_num, layer.head_dim),
v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
causal=True,
sm_scale=layer.scaling,
logits_soft_cap=logits_soft_cap,
)
o = o1
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
self._get_wrapper_idx(layer) self._get_wrapper_idx(layer)
] ]
@@ -452,6 +517,28 @@ class FlashInferAttnBackend(AttentionBackend):
else forward_batch.encoder_out_cache_loc else forward_batch.encoder_out_cache_loc
) )
if self.enable_flashinfer_mla:
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
o = decode_wrapper.run(
reshaped_q[:, :, : layer.v_head_dim],
reshaped_q[:, :, layer.v_head_dim :],
reshaped_k[:, :, : layer.v_head_dim],
reshaped_k[:, :, layer.v_head_dim :],
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
if k is not None: if k is not None:
assert v is not None assert v is not None
if save_kv_cache: if save_kv_cache:
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], decode_wrappers: List[
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo], spec_info: Optional[SpecInfo],
): ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor, seq_lens: torch.Tensor,
seq_lens_sum: int, seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], decode_wrappers: List[
Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
],
encoder_lens: Optional[torch.Tensor], encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo], spec_info: Optional[SpecInfo],
): ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
def call_begin_forward( def call_begin_forward(
self, self,
wrapper: BatchDecodeWithPagedKVCacheWrapper, wrapper: Union[
BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
],
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor, paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int, paged_kernel_lens_sum: int,
@@ -637,6 +730,25 @@ class FlashInferIndicesUpdaterDecode:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1 bs = kv_indptr.shape[0] - 1
if global_config.enable_flashinfer_mla:
sm_scale = 1.0 / math.sqrt(192)
q_indptr = torch.arange(0, bs + 1).to(0).int()
kv_lens = paged_kernel_lens.to(torch.int32)
wrapper.plan(
q_indptr,
kv_indptr,
kv_indices,
kv_lens,
self.num_qo_heads,
512,
64,
1,
False,
sm_scale,
self.data_type,
self.data_type,
)
else:
wrapper.begin_forward( wrapper.begin_forward(
kv_indptr, kv_indptr,
kv_indices, kv_indices,
@@ -857,6 +969,17 @@ class FlashInferIndicesUpdaterPrefill:
# extend part # extend part
if use_ragged: if use_ragged:
if global_config.enable_flashinfer_mla:
wrapper_ragged.begin_forward(
qo_indptr=qo_indptr,
kv_indptr=qo_indptr,
num_qo_heads=self.num_qo_heads,
num_kv_heads=self.num_kv_heads,
head_dim_qk=192,
head_dim_vo=128,
q_data_type=self.q_data_type,
)
else:
wrapper_ragged.begin_forward( wrapper_ragged.begin_forward(
qo_indptr, qo_indptr,
qo_indptr, qo_indptr,
@@ -866,6 +989,7 @@ class FlashInferIndicesUpdaterPrefill:
q_data_type=self.q_data_type, q_data_type=self.q_data_type,
) )
if not global_config.enable_flashinfer_mla:
# cached part # cached part
wrapper_paged.begin_forward( wrapper_paged.begin_forward(
qo_indptr, qo_indptr,
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
window_left, window_left,
logits_soft_cap, logits_soft_cap,
head_dim, head_dim,
head_dim,
empty_q_data, empty_q_data,
empty_kv_cache, empty_kv_cache,
stream.cuda_stream, stream.cuda_stream,

View File

@@ -65,6 +65,7 @@ global_server_args_dict = {
"enable_dp_attention": ServerArgs.enable_dp_attention, "enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe, "enable_ep_moe": ServerArgs.enable_ep_moe,
"device": ServerArgs.device, "device": ServerArgs.device,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
} }
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -67,6 +67,7 @@ from sglang.srt.utils import (
monkey_patch_p2p_access_check, monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config, monkey_patch_vllm_gguf_config,
set_cpu_offload_max_bytes, set_cpu_offload_max_bytes,
set_cuda_arch,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -110,6 +111,12 @@ class ModelRunner:
): ):
# TODO: add MLA optimization on CPU # TODO: add MLA optimization on CPU
if self.server_args.device != "cpu": if self.server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
logger.info(
"FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
)
self.server_args.attention_backend = "flashinfer"
else:
logger.info("MLA optimization is turned on. Use triton backend.") logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton" self.server_args.attention_backend = "triton"
@@ -169,6 +176,7 @@ class ModelRunner:
"enable_dp_attention": server_args.enable_dp_attention, "enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe, "enable_ep_moe": server_args.enable_ep_moe,
"device": server_args.device, "device": server_args.device,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
} }
) )
@@ -292,6 +300,8 @@ class ModelRunner:
if torch.cuda.get_device_capability()[1] < 5: if torch.cuda.get_device_capability()[1] < 5:
raise RuntimeError("SGLang only supports sm75 and above.") raise RuntimeError("SGLang only supports sm75 and above.")
set_cuda_arch()
# Prepare the model config # Prepare the model config
self.load_config = LoadConfig( self.load_config = LoadConfig(
load_format=self.server_args.load_format, load_format=self.server_args.load_format,

View File

@@ -510,7 +510,13 @@ class DeepseekV2AttentionMLA(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
# Use normal computation for prefill and use weight absorption for extend/decode if global_server_args_dict["enable_flashinfer_mla"]:
if forward_batch.forward_mode.is_extend():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
return self.forward_absorb(positions, hidden_states, forward_batch)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
if ( if (
forward_batch.forward_mode.is_extend() forward_batch.forward_mode.is_extend()
and forward_batch.extend_prefix_lens.sum() == 0 and forward_batch.extend_prefix_lens.sum() == 0

View File

@@ -168,6 +168,8 @@ class ServerArgs:
tool_call_parser: str = None tool_call_parser: str = None
enable_hierarchical_cache: bool = False enable_hierarchical_cache: bool = False
enable_flashinfer_mla: bool = False
def __post_init__(self): def __post_init__(self):
# Set missing default values # Set missing default values
if self.tokenizer_path is None: if self.tokenizer_path is None:
@@ -693,6 +695,11 @@ class ServerArgs:
default=ServerArgs.grammar_backend, default=ServerArgs.grammar_backend,
help="Choose the backend for grammar-guided decoding.", help="Choose the backend for grammar-guided decoding.",
) )
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization",
)
# Speculative decoding # Speculative decoding
parser.add_argument( parser.add_argument(

View File

@@ -1444,3 +1444,10 @@ def launch_dummy_health_check_server(host, port):
timeout_keep_alive=5, timeout_keep_alive=5,
loop="uvloop", loop="uvloop",
) )
def set_cuda_arch():
if is_flashinfer_available():
capability = torch.cuda.get_device_capability()
arch = f"{capability[0]}.{capability[1]}"
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"

View File

@@ -4,17 +4,19 @@ set -euxo pipefail
# Install the dependency in CI. # Install the dependency in CI.
# Use repo from environment variable, passed from GitHub Actions # Use repo from environment variable, passed from GitHub Actions
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}" FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
bash "${SCRIPT_DIR}/killall_sglang.sh" bash "${SCRIPT_DIR}/killall_sglang.sh"
pip install --upgrade pip pip install --upgrade pip
pip uninstall flashinfer -y pip uninstall flashinfer -y
pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
rm -rf /root/.cache/flashinfer
# Force reinstall flashinfer and torch_memory_saver # Force reinstall flashinfer and torch_memory_saver
pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps pip install flashinfer_python==0.2.1.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
pip install torch_memory_saver --force-reinstall pip install torch_memory_saver --force-reinstall
pip install transformers==4.45.2 sentence_transformers accelerate peft pip install transformers==4.45.2 sentence_transformers accelerate peft

View File

@@ -28,6 +28,7 @@ class TestEAGLEEngine(unittest.TestCase):
"speculative_eagle_topk": 8, "speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64, "speculative_num_draft_tokens": 64,
"mem_fraction_static": 0.7, "mem_fraction_static": 0.7,
"cuda_graph_max_bs": 32,
} }
def setUp(self): def setUp(self):
@@ -124,6 +125,8 @@ class TestEAGLEServer(unittest.TestCase):
"64", "64",
"--mem-fraction-static", "--mem-fraction-static",
"0.7", "0.7",
"--cuda-graph-max-bs",
"32",
], ],
) )