feat: add fast_decode_plan from flashinfer, flashinfer to 0.4.0rc3 (#10760)
Co-authored-by: Zihao Ye <yezihhhao@gmail.com> Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -62,7 +62,7 @@ dependencies = [
|
|||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
"flashinfer_python==0.4.0rc1",
|
"flashinfer_python==0.4.0rc3",
|
||||||
"openai==1.99.1",
|
"openai==1.99.1",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"anthropic>=0.20.0",
|
"anthropic>=0.20.0",
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ srt = [
|
|||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
"flashinfer_python==0.4.0rc1",
|
"flashinfer_python==0.4.0rc3",
|
||||||
]
|
]
|
||||||
|
|
||||||
blackwell = [
|
blackwell = [
|
||||||
@@ -80,8 +80,8 @@ blackwell = [
|
|||||||
"torchaudio==2.8.0",
|
"torchaudio==2.8.0",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"cuda-python",
|
"cuda-python",
|
||||||
"flashinfer_python==0.4.0rc1",
|
"flashinfer_python==0.4.0rc3",
|
||||||
"nvidia-cutlass-dsl==4.2.1",
|
"nvidia-cutlass-dsl==4.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
|
|||||||
@@ -703,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if server_args.attention_backend == "flashinfer":
|
if server_args.attention_backend == "flashinfer":
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer_python",
|
"flashinfer_python",
|
||||||
"0.4.0rc1",
|
"0.4.0rc3",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if is_flashinfer_available():
|
|||||||
BatchDecodeWithPagedKVCacheWrapper,
|
BatchDecodeWithPagedKVCacheWrapper,
|
||||||
BatchPrefillWithPagedKVCacheWrapper,
|
BatchPrefillWithPagedKVCacheWrapper,
|
||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
|
fast_decode_plan,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
from flashinfer.decode import _get_range_buf, get_seq_lens
|
from flashinfer.decode import _get_range_buf, get_seq_lens
|
||||||
@@ -842,23 +843,51 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
global_override_indptr_cpu[0] = 0
|
global_override_indptr_cpu[0] = 0
|
||||||
global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
|
global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
|
||||||
|
|
||||||
wrapper.begin_forward(
|
# Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
|
||||||
kv_indptr,
|
# by checking if it's a partial function with fast_decode_plan as the func
|
||||||
kv_indices,
|
wrapper_uses_fast_decode_plan = (
|
||||||
self.kv_last_page_len[:bs],
|
hasattr(wrapper.begin_forward, "func")
|
||||||
self.num_qo_heads,
|
and wrapper.begin_forward.func == fast_decode_plan
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_dim,
|
|
||||||
1,
|
|
||||||
data_type=self.data_type,
|
|
||||||
q_data_type=self.q_data_type,
|
|
||||||
non_blocking=True,
|
|
||||||
fixed_split_size=fixed_split_size,
|
|
||||||
disable_split_kv=(
|
|
||||||
disable_split_kv if disable_split_kv is not None else False
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if wrapper_uses_fast_decode_plan:
|
||||||
|
# When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
|
||||||
|
wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
1,
|
||||||
|
data_type=self.data_type,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
non_blocking=True,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
|
disable_split_kv=(
|
||||||
|
disable_split_kv if disable_split_kv is not None else False
|
||||||
|
),
|
||||||
|
global_override_indptr_cpu=global_override_indptr_cpu,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# When using original begin_forward, don't pass global_override_indptr_cpu
|
||||||
|
wrapper.begin_forward(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
self.kv_last_page_len[:bs],
|
||||||
|
self.num_qo_heads,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.head_dim,
|
||||||
|
1,
|
||||||
|
data_type=self.data_type,
|
||||||
|
q_data_type=self.q_data_type,
|
||||||
|
non_blocking=True,
|
||||||
|
fixed_split_size=fixed_split_size,
|
||||||
|
disable_split_kv=(
|
||||||
|
disable_split_kv if disable_split_kv is not None else False
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
if locally_override:
|
if locally_override:
|
||||||
global_override_indptr_cpu = None
|
global_override_indptr_cpu = None
|
||||||
|
|
||||||
@@ -1328,174 +1357,3 @@ def should_use_tensor_core(
|
|||||||
return gqa_group_size >= 4
|
return gqa_group_size >= 4
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Use as a fast path to override the indptr in flashinfer's plan function
|
|
||||||
# This is used to remove some host-to-device copy overhead.
|
|
||||||
global_override_indptr_cpu = None
|
|
||||||
|
|
||||||
|
|
||||||
def fast_decode_plan(
    self,
    indptr: torch.Tensor,
    indices: torch.Tensor,
    last_page_len: torch.Tensor,
    num_qo_heads: int,
    num_kv_heads: int,
    head_dim: int,
    page_size: int,
    pos_encoding_mode: str = "NONE",
    window_left: int = -1,
    logits_soft_cap: Optional[float] = None,
    q_data_type: Optional[Union[str, torch.dtype]] = None,
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
    data_type: Optional[Union[str, torch.dtype]] = None,
    sm_scale: Optional[float] = None,
    rope_scale: Optional[float] = None,
    rope_theta: Optional[float] = None,
    non_blocking: bool = True,
    fixed_split_size: Optional[int] = None,
    disable_split_kv: bool = False,
) -> None:
    """
    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
    Modifications:
    - Remove unnecessary device-to-device copy for the cuda graph buffers.
    - Remove unnecessary host-to-device copy for the metadata buffers.

    Args:
        self: The decode wrapper whose planning state is populated. This function
            reads ``use_tensor_cores``, ``is_cuda_graph_enabled``, ``device``,
            the workspace buffers and ``_cached_module`` from it.
        indptr: Paged-KV indptr. Outside CUDA-graph mode it is stored on the
            wrapper by reference (no copy). It is copied to the host only when
            the module-level ``global_override_indptr_cpu`` is ``None``.
        indices: Paged-KV indices; stored by reference outside CUDA-graph mode.
        last_page_len: One entry per request; its length defines the batch size.
        num_qo_heads, num_kv_heads, head_dim, page_size: Forwarded to the
            underlying ``plan`` call.
        pos_encoding_mode, window_left, logits_soft_cap, sm_scale, rope_scale,
            rope_theta: Recorded on the wrapper after planning.
            ``logits_soft_cap=None`` is normalized to ``0.0``.
        q_data_type, kv_data_type, data_type: Query / KV dtypes, as a
            ``torch.dtype`` or the name of a ``torch`` attribute. ``data_type``
            fills whichever of the other two is unset; the query dtype falls
            back to ``"float16"`` and the KV dtype to the query dtype.
        non_blocking: Passed to the host-to-device copy of the qo indptr buffer.
        fixed_split_size: Forwarded to the tensor-core ``plan`` call; ``None``
            is replaced by ``-1`` on that path.
        disable_split_kv: Forwarded to the tensor-core ``plan`` call.

    Raises:
        ValueError: In CUDA-graph mode, if the runtime batch size differs from
            the fixed batch size or ``indices`` exceeds the allocated buffer.
        RuntimeError: If the underlying ``plan`` call fails; the original
            exception is attached as ``__cause__``.
    """
    batch_size = len(last_page_len)
    if logits_soft_cap is None:
        logits_soft_cap = 0.0

    # Handle data types consistently
    if data_type is not None:
        if q_data_type is None:
            q_data_type = data_type
        if kv_data_type is None:
            kv_data_type = data_type
    elif q_data_type is None:
        q_data_type = "float16"

    if kv_data_type is None:
        kv_data_type = q_data_type

    if self.use_tensor_cores:
        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
        # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
        if fixed_split_size is None:
            fixed_split_size = -1

    if self.is_cuda_graph_enabled:
        # Only validate against the pre-allocated buffers here; nothing is
        # copied into them by this function.
        if batch_size != self._fixed_batch_size:
            raise ValueError(
                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
                " mismatches the batch size set during initialization {}".format(
                    batch_size, self._fixed_batch_size
                )
            )
        if len(indices) > len(self._paged_kv_indices_buf):
            raise ValueError(
                "The size of indices should be less than or equal to the allocated buffer"
            )
    else:
        # Adopt the caller's tensors directly instead of copying them.
        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len_buf = last_page_len
        if self.use_tensor_cores:
            self._qo_indptr_buf = qo_indptr_host.to(
                self.device, non_blocking=non_blocking
            )

    # Create empty tensors for dtype info if needed
    empty_q_data = torch.empty(
        0,
        dtype=(
            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
        ),
        device=self.device,
    )

    empty_kv_cache = torch.empty(
        0,
        dtype=(
            getattr(torch, kv_data_type)
            if isinstance(kv_data_type, str)
            else kv_data_type
        ),
        device=self.device,
    )

    # Prefer the host-side indptr published through the module-level override
    # so that no device-to-host copy is needed.
    indptr_host = (
        global_override_indptr_cpu
        if global_override_indptr_cpu is not None
        else indptr.cpu()
    )

    with torch.cuda.device(self.device):

        if self.use_tensor_cores:
            # ALSO convert last_page_len to CPU
            if page_size == 1:
                # When page size is 1, last_page_len is always 1.
                # Directly construct the host tensor rather than executing a device-to-host copy.
                last_page_len_host = torch.ones(
                    (batch_size,), dtype=torch.int32, device="cpu"
                )
            else:
                last_page_len_host = last_page_len.cpu()

            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

            try:
                # The tensor core version takes 18 positional arguments.
                self._plan_info = self._cached_module.plan(
                    self._float_workspace_buffer,
                    self._int_workspace_buffer,
                    self._pin_memory_int_workspace_buffer,
                    qo_indptr_host,
                    indptr_host,
                    kv_lens_arr_host,
                    batch_size,  # total_num_rows
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    self.is_cuda_graph_enabled,
                    head_dim,
                    head_dim,
                    False,  # causal
                    window_left,
                    fixed_split_size,
                    disable_split_kv,
                )
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise RuntimeError(f"Error in tensor core plan: {e}") from e
        else:
            try:
                # Make sure we pass exactly 15 arguments for standard version
                self._plan_info = self._cached_module.plan(
                    self._float_workspace_buffer,
                    self._int_workspace_buffer,
                    self._pin_memory_int_workspace_buffer,
                    indptr_host,
                    batch_size,
                    num_qo_heads,
                    num_kv_heads,
                    page_size,
                    self.is_cuda_graph_enabled,
                    window_left,
                    logits_soft_cap,
                    head_dim,
                    head_dim,
                    empty_q_data,
                    empty_kv_cache,
                )
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise RuntimeError(f"Error in standard plan: {e}") from e

    self._pos_encoding_mode = pos_encoding_mode
    self._window_left = window_left
    self._logits_soft_cap = logits_soft_cap
    self._sm_scale = sm_scale
    self._rope_scale = rope_scale
    self._rope_theta = rope_theta
|
|
||||||
|
|||||||
Reference in New Issue
Block a user