diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 851555aea..d22bc9a88 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -96,8 +96,6 @@ jobs:
         uses: actions/checkout@v4
 
       - name: Install dependencies
-        env:
-          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
         run: |
           bash scripts/ci_install_dependency.sh
 
diff --git a/docs/start/install.md b/docs/start/install.md
index f3e4ef97b..9c8435bd2 100644
--- a/docs/start/install.md
+++ b/docs/start/install.md
@@ -164,4 +164,4 @@ sky status --endpoint 30000 sglang
 - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub.
 - If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 - The language frontend operates independently of the backend runtime. You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime.
-- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.3" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
+- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python==0.2.5" -i https://flashinfer.ai/whl/cu124/torch2.6 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`.
diff --git a/python/pyproject.toml b/python/pyproject.toml
index cf2bed1d3..0dadcb0a9 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -37,7 +37,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
-    "torchao>=0.7.0",
+    "torchao>=0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.1.0",
-    "flashinfer_python==0.2.3",
+    "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
     "cuda-python",
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 3e8222f13..f26437db9 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 8586006dc..b590790be 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, k_scale, v_scale
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-            # Skip these copies because we directly write to them during prepartion
-            # self._paged_kv_indptr_buf.copy_(indptr)
-            # self._paged_kv_indices_buf[: len(indices)] = indices
-            # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
 
-    self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+    if self.use_tensor_cores:
+        self._qo_indptr_buf = qo_indptr_host.to(
+            self.device, non_blocking=non_blocking
+        )
 
-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
 
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
     indptr_host = (
         global_override_indptr_cpu
@@ -1215,48 +1231,57 @@ def fast_decode_plan(
         else indptr.cpu()
     )
 
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(
-            indptr_host, self.last_page_len[:batch_size], page_size
-        )
+    with torch.cuda.device(self.device):
 
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_host,
-            indptr_host,
-            kv_lens_arr_host,
-            batch_size,  # total_num_rows
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            head_dim,
-            head_dim,
-            False,  # causal
-            torch.cuda.current_stream().cuda_stream,
-        )
-    else:
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            indptr_host,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            window_left,
-            logits_soft_cap,
-            head_dim,
-            head_dim,
-            self.empty_q_data,
-            self.empty_kv_cache,
-            torch.cuda.current_stream().cuda_stream,
-        )
+        if self.use_tensor_cores:
+            # Also convert last_page_len to CPU
+            last_page_len_host = last_page_len.cpu()
+
+            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+            try:
+                # Make sure we pass exactly 15 arguments for the tensor core version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    qo_indptr_host,
+                    indptr_host,
+                    kv_lens_arr_host,
+                    batch_size,  # total_num_rows
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    head_dim,
+                    head_dim,
+                    False,  # causal
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in tensor core plan: {e}")
+        else:
+            try:
+                # Make sure we pass exactly 15 arguments for the standard version
+                self._plan_info = self._cached_module.plan(
+                    self._float_workspace_buffer,
+                    self._int_workspace_buffer,
+                    self._pin_memory_int_workspace_buffer,
+                    indptr_host,
+                    batch_size,
+                    num_qo_heads,
+                    num_kv_heads,
+                    page_size,
+                    self.is_cuda_graph_enabled,
+                    window_left,
+                    logits_soft_cap,
+                    head_dim,
+                    head_dim,
+                    empty_q_data,
+                    empty_kv_cache,
+                )
+            except Exception as e:
+                raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 81afcb9da..485ed5f2c 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                     k,
                     v,
                 )
+
+        # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
             reshaped_q[:, :, layer.v_head_dim :],
-            reshaped_k[:, :, : layer.v_head_dim],
-            reshaped_k[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_cpu,
-            kv_indptr_cpu,
-            kv_len_arr_cpu,
-            num_heads,
-            head_dim_ckv,
-            causal,
-            stream,
-        )
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")
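
Note on the fast_decode_plan dtype handling above: the deprecated data_type argument is kept only as a fallback, the newer q_data_type / kv_data_type arguments take precedence, and the KV-cache dtype defaults to the query dtype. The standalone Python sketch below mirrors that resolution logic for illustration only; the helper name resolve_dtypes is hypothetical and not part of the patch.

def resolve_dtypes(data_type=None, q_data_type=None, kv_data_type=None):
    # Legacy callers may still pass data_type; use it to fill both dtypes.
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"
    # The KV-cache dtype falls back to the query dtype when not given.
    if kv_data_type is None:
        kv_data_type = q_data_type
    return q_data_type, kv_data_type

assert resolve_dtypes() == ("float16", "float16")
assert resolve_dtypes(data_type="bfloat16") == ("bfloat16", "bfloat16")
assert resolve_dtypes(q_data_type="float16", kv_data_type="float8_e4m3fn") == (
    "float16",
    "float8_e4m3fn",
)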