diff --git a/python/pyproject.toml b/python/pyproject.toml
index d6480ebb6..eac224443 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -62,7 +62,7 @@ dependencies = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.4.0rc3",
     "openai==1.99.1",
     "tiktoken",
     "anthropic>=0.20.0",
diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml
index 6446dcd78..dd282dc5b 100755
--- a/python/pyproject_other.toml
+++ b/python/pyproject_other.toml
@@ -70,7 +70,7 @@ srt = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
+    "flashinfer_python==0.4.0rc3",
 ]
 
 blackwell = [
@@ -80,8 +80,8 @@ blackwell = [
     "torchaudio==2.8.0",
     "torchvision",
     "cuda-python",
-    "flashinfer_python==0.4.0rc1",
-    "nvidia-cutlass-dsl==4.2.1",
+    "flashinfer_python==0.4.0rc3",
+    "nvidia-cutlass-dsl==4.2.0",
 ]
 
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index f6a0f597b..840ed332d 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.4.0rc1",
+            "0.4.0rc3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 5f2b946f3..048319202 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -47,6 +47,7 @@ if is_flashinfer_available():
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -842,23 +843,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
 
-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-            fixed_split_size=fixed_split_size,
-            disable_split_kv=(
-                disable_split_kv if disable_split_kv is not None else False
-            ),
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )
+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
 
         if locally_override:
             global_override_indptr_cpu = None
@@ -1328,174 +1357,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-    fixed_split_size: Optional[int] = None,
-    disable_split_kv: bool = False,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-        # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
-        if fixed_split_size is None:
-            fixed_split_size = -1
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                    window_left,
-                    fixed_split_size,
-                    disable_split_kv,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
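
A note on the new `wrapper_uses_fast_decode_plan` check in `flashinfer_backend.py`: per the removed docstring, `fast_decode_plan` is installed for `FlashInferMultiStepDraftBackend` by rebinding a wrapper's `begin_forward`, and the check relies on the fact that a `functools.partial` object exposes its wrapped callable via the `.func` attribute, which a plain bound method does not have. A minimal, self-contained sketch of that detection pattern (the `Wrapper` class and function bodies here are hypothetical stand-ins, not SGLang's real ones):

```python
from functools import partial


def fast_decode_plan(self, *args, **kwargs):
    # Stand-in for flashinfer's fast_decode_plan; accepts the extra kwarg.
    print("fast path, override passed:", "global_override_indptr_cpu" in kwargs)


class Wrapper:
    def begin_forward(self, *args, **kwargs):
        # Stand-in for the stock begin_forward; a bound method has no .func.
        print("stock path")


wrapper = Wrapper()
# MultiStepDraftBackend-style patch: bind the wrapper instance as `self`.
wrapper.begin_forward = partial(fast_decode_plan, wrapper)

# The check from the diff: a functools.partial exposes the wrapped
# callable as .func, so this distinguishes patched from stock wrappers.
uses_fast_plan = (
    hasattr(wrapper.begin_forward, "func")
    and wrapper.begin_forward.func == fast_decode_plan
)
if uses_fast_plan:
    wrapper.begin_forward(global_override_indptr_cpu=object())
else:
    wrapper.begin_forward()
```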
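Similarly, the `global_override_indptr_cpu` fast path that this PR now sources from flashinfer itself exists to skip the `indptr.cpu()` device-to-host copy inside plan: the caller already holds `seq_lens_cpu`, so it can assemble the host-side indptr directly, exactly as the context lines at the top of the `@@ -842` hunk do. A rough standalone sketch of that computation, with a plain CPU tensor standing in for the preallocated override buffer (names here are illustrative, not the real call site's):

```python
import torch


def build_host_indptr(seq_lens_cpu: torch.Tensor, max_bs: int) -> torch.Tensor:
    # Preallocated CPU-side buffer standing in for global_override_indptr_cpu.
    indptr_cpu = torch.zeros(max_bs + 1, dtype=torch.int32)
    bs = seq_lens_cpu.numel()
    indptr_cpu[0] = 0
    indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
    # plan() can read this buffer directly instead of calling indptr.cpu(),
    # which would force a device-to-host sync on every decode step.
    return indptr_cpu[: bs + 1]


indptr = build_host_indptr(torch.tensor([3, 5, 2]), max_bs=8)
print(indptr)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```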