diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 1fd007e86..81d4349ee 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -72,7 +72,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -98,7 +98,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -123,7 +123,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -163,7 +163,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -209,7 +209,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -243,7 +243,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
@@ -283,7 +283,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
           git clone https://github.com/merrymercy/human-eval.git
@@ -308,7 +308,7 @@ jobs:
 
       - name: Install dependencies
         env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }}
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
           git clone https://github.com/merrymercy/human-eval.git
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 34ada35dc..896c7a41e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -21,12 +21,13 @@ runtime_common = [
     "hf_transfer", "huggingface_hub", "interegular", "modelscope",
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2",
-    "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10"
+    "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10", "ninja"
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
-    "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
+    "flashinfer_python>=0.2.1.post1",
+    "outlines>=0.0.44,<=0.1.11",
 ]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index d557e6a6e..ac034ec0a 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -38,5 +38,7 @@ class GlobalConfig:
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
 
+        self.enable_flashinfer_mla = False
+
 
 global_config = GlobalConfig()
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 942a53c37..b0e780706 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.0.post2",
+            "0.2.1.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 99708135a..2c4c6c65b 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import PosEncodingMode
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper
 
 
 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
 
 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
 
 
 @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024
 
+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
                 self.prefill_wrappers_verify.append(
                     BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                 )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )
 
         # Create indices updater
         if not skip_prefill:
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
-
-        logits_soft_cap = layer.logit_cap
-
-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
             )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+
+            logits_soft_cap = layer.logit_cap
+
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
 
-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
-                    sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
+            o = o1
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
                 )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
+                    q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )
+
+                    o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
         self,
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )
 
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
 
     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )
 
 
 class FlashInferIndicesUpdaterPrefill:
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.begin_forward(
-                qo_indptr,
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )
-
-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
 
 
 class FlashInferMultiStepDraftBackend:
     """
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
             window_left,
             logits_soft_cap,
             head_dim,
+            head_dim,
             empty_q_data,
             empty_kv_cache,
             stream.cuda_stream,
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index ecac38656..8ff0ff7e7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -65,6 +65,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
 }
 
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d125868b0..3242c0d61 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
 
 logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -169,6 +176,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
             }
         )
 
@@ -292,6 +300,8 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        set_cuda_arch()
+
         # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2a1c75cc4..9046f227d 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -510,14 +510,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        # Use normal computation for prefill and use weight absorption for extend/decode
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.extend_prefix_lens.sum() == 0
-        ):
-            return self.forward_normal(positions, hidden_states, forward_batch)
+        if global_server_args_dict["enable_flashinfer_mla"]:
+            if forward_batch.forward_mode.is_extend():
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
-            return self.forward_absorb(positions, hidden_states, forward_batch)
+            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
+            if (
+                forward_batch.forward_mode.is_extend()
+                and forward_batch.extend_prefix_lens.sum() == 0
+            ):
+                return self.forward_normal(positions, hidden_states, forward_batch)
+            else:
+                return self.forward_absorb(positions, hidden_states, forward_batch)
 
     def forward_normal(
         self,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 93f797087..a8ab27cc9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -168,6 +168,8 @@ class ServerArgs:
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
 
+    enable_flashinfer_mla: bool = False
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
@@ -693,6 +695,11 @@ class ServerArgs:
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mla",
+            action="store_true",
+            help="Enable FlashInfer MLA optimization",
+        )
 
         # Speculative decoding
         parser.add_argument(
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index b1c49f527..68ad77846 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1444,3 +1444,10 @@ def launch_dummy_health_check_server(host, port):
         timeout_keep_alive=5,
         loop="uvloop",
     )
+
+
+def set_cuda_arch():
+    if is_flashinfer_available():
+        capability = torch.cuda.get_device_capability()
+        arch = f"{capability[0]}.{capability[1]}"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh
index cce3042f4..7cd6acd74 100755
--- a/scripts/ci_install_dependency.sh
+++ b/scripts/ci_install_dependency.sh
@@ -4,17 +4,19 @@ set -euxo pipefail
 # Install the dependency in CI.
 
 # Use repo from environment variable, passed from GitHub Actions
-FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}"
+FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
 
 SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
 bash "${SCRIPT_DIR}/killall_sglang.sh"
 
 pip install --upgrade pip
 pip uninstall flashinfer -y
-pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
+pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+rm -rf /root/.cache/flashinfer
 
 # Force reinstall flashinfer and torch_memory_saver
-pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+pip install flashinfer_python==0.2.1.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps
+
 pip install torch_memory_saver --force-reinstall
 
 pip install transformers==4.45.2 sentence_transformers accelerate peft
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 5e627fd11..f2d2bae70 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -28,6 +28,7 @@ class TestEAGLEEngine(unittest.TestCase):
         "speculative_eagle_topk": 8,
         "speculative_num_draft_tokens": 64,
         "mem_fraction_static": 0.7,
+        "cuda_graph_max_bs": 32,
     }
 
     def setUp(self):
@@ -124,6 +125,8 @@ class TestEAGLEServer(unittest.TestCase):
                 "64",
                 "--mem-fraction-static",
                 "0.7",
+                "--cuda-graph-max-bs",
+                "32",
             ],
         )
 