Bump Flashinfer to 0.2.5 (#5870)
Co-authored-by: Yuhao Chen <yxckeis8@gmail.com>
@@ -37,7 +37,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
-    "torchao>=0.7.0",
+    "torchao>=0.9.0",
     "transformers==4.51.1",
     "uvicorn",
     "uvloop",
@@ -47,7 +47,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.1.0",
-    "flashinfer_python==0.2.3",
+    "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",
     "cuda-python",
@@ -453,7 +453,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -82,8 +87,6 @@ class FlashInferAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
         self.is_multimodal = model_runner.model_config.is_multimodal
-        self.kv_cache_dtype = model_runner.kv_cache_dtype
-        self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
 
         assert not (
             model_runner.sliding_window_size is not None
@@ -268,6 +271,12 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
 
+        # Ensure tensors are properly allocated
+        for i in range(self.num_wrappers):
+            # Force allocation by performing a small operation
+            if len(self.cuda_graph_kv_indices[i]) > 0:
+                self.cuda_graph_kv_indices[i][0] = 0
+
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -396,8 +405,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -414,7 +421,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 assert v is not None
                 if save_kv_cache:
                     forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, k_scale, v_scale
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                     )
 
             o = prefill_wrapper_paged.forward(
@@ -424,8 +431,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -452,7 +459,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -466,8 +473,6 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        k_scale = layer.k_scale_float if self.kv_cache_dtype_str != "auto" else None
-        v_scale = layer.v_scale_float if self.kv_cache_dtype_str != "auto" else None
         decode_wrapper = self.forward_metadata.decode_wrappers[
             self._get_wrapper_idx(layer)
         ]
@@ -481,16 +486,17 @@ class FlashInferAttnBackend(AttentionBackend):
             assert v is not None
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, k_scale, v_scale
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                 )
 
+        # Call the wrapped function
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=k_scale,
-            v_scale=v_scale,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -1146,8 +1152,9 @@ def fast_decode_plan(
     pos_encoding_mode: str = "NONE",
     window_left: int = -1,
     logits_soft_cap: Optional[float] = None,
-    data_type: Union[str, torch.dtype] = "float16",
     q_data_type: Optional[Union[str, torch.dtype]] = None,
+    kv_data_type: Optional[Union[str, torch.dtype]] = None,
+    data_type: Optional[Union[str, torch.dtype]] = None,
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
@@ -1163,6 +1170,18 @@ def fast_decode_plan(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
+    # Handle data types consistently
+    if data_type is not None:
+        if q_data_type is None:
+            q_data_type = data_type
+        if kv_data_type is None:
+            kv_data_type = data_type
+    elif q_data_type is None:
+        q_data_type = "float16"
+
+    if kv_data_type is None:
+        kv_data_type = q_data_type
+
     if self.use_tensor_cores:
         qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
@@ -1178,36 +1197,33 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
-        # Skip these copies because we directly write to them during prepartion
-        # self._paged_kv_indptr_buf.copy_(indptr)
-        # self._paged_kv_indices_buf[: len(indices)] = indices
-        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
-        self._qo_indptr_buf = qo_indptr_host.to(self.device, non_blocking=non_blocking)
+        if self.use_tensor_cores:
+            self._qo_indptr_buf = qo_indptr_host.to(
+                self.device, non_blocking=non_blocking
+            )
 
-    # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
-    if not q_data_type:
-        q_data_type = data_type
+    # Create empty tensors for dtype info if needed
+    empty_q_data = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
+        ),
+        device=self.device,
+    )
 
-    if not hasattr(self, "empty_q_data"):
-        self.empty_q_data = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, q_data_type)
-                if isinstance(q_data_type, str)
-                else q_data_type
-            ),
-        )
-        self.empty_kv_cache = torch.empty(
-            0,
-            dtype=(
-                getattr(torch, data_type) if isinstance(data_type, str) else data_type
-            ),
-        )
-        self.last_page_len = torch.ones(32768, dtype=torch.int32)
+    empty_kv_cache = torch.empty(
+        0,
+        dtype=(
+            getattr(torch, kv_data_type)
+            if isinstance(kv_data_type, str)
+            else kv_data_type
+        ),
+        device=self.device,
+    )
 
     indptr_host = (
         global_override_indptr_cpu
@@ -1215,48 +1231,57 @@ def fast_decode_plan(
         else indptr.cpu()
     )
 
-    if self.use_tensor_cores:
-        kv_lens_arr_host = get_seq_lens(
-            indptr_host, self.last_page_len[:batch_size], page_size
-        )
-        with torch.cuda.device(self.device):
-            self._plan_info = self._cached_module.plan(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_host,
-                indptr_host,
-                kv_lens_arr_host,
-                batch_size,  # total_num_rows
-                batch_size,
-                num_qo_heads,
-                num_kv_heads,
-                page_size,
-                self.is_cuda_graph_enabled,
-                head_dim,
-                head_dim,
-                False,  # causal
-                torch.cuda.current_stream().cuda_stream,
-            )
-    else:
-        self._plan_info = self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            indptr_host,
-            batch_size,
-            num_qo_heads,
-            num_kv_heads,
-            page_size,
-            self.is_cuda_graph_enabled,
-            window_left,
-            logits_soft_cap,
-            head_dim,
-            head_dim,
-            self.empty_q_data,
-            self.empty_kv_cache,
-            torch.cuda.current_stream().cuda_stream,
-        )
+    if self.use_tensor_cores:
+        # ALSO convert last_page_len to CPU
+        last_page_len_host = last_page_len.cpu()
+
+        kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
+
+        try:
+            # Make sure we pass exactly 15 arguments for tensor core version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_host,
+                indptr_host,
+                kv_lens_arr_host,
+                batch_size,  # total_num_rows
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                head_dim,
+                head_dim,
+                False,  # causal
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in standard plan: {e}")
+    else:
+        try:
+            # Make sure we pass exactly 15 arguments for standard version
+            self._plan_info = self._cached_module.plan(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                indptr_host,
+                batch_size,
+                num_qo_heads,
+                num_kv_heads,
+                page_size,
+                self.is_cuda_graph_enabled,
+                window_left,
+                logits_soft_cap,
+                head_dim,
+                head_dim,
+                empty_q_data,
+                empty_kv_cache,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in standard plan: {e}")
 
     self._pos_encoding_mode = pos_encoding_mode
     self._window_left = window_left
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,11 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import torch._dynamo
+
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -388,14 +394,17 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 k,
                 v,
             )
 
+        # Reshape inputs
         reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
         reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
             reshaped_q[:, :, : layer.v_head_dim],
             reshaped_q[:, :, layer.v_head_dim :],
-            reshaped_k[:, :, : layer.v_head_dim],
-            reshaped_k[:, :, layer.v_head_dim :],
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -825,16 +834,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_cpu,
-            kv_indptr_cpu,
-            kv_len_arr_cpu,
-            num_heads,
-            head_dim_ckv,
-            causal,
-            stream,
-        )
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")