Use a single workspace for flashinfer (#1077)
This commit is contained in:
@@ -27,7 +27,7 @@ class GlobalConfig:
|
||||
# Runtime constants: others
|
||||
self.num_continue_decode_steps = 10
|
||||
self.retract_decode_steps = 20
|
||||
self.flashinfer_workspace_size = 192 * 1024 * 1024
|
||||
self.flashinfer_workspace_size = 384 * 1024 * 1024
|
||||
|
||||
# Output tokenization configs
|
||||
self.skip_special_tokens_in_output = True
|
||||
|
||||
@@ -120,13 +120,13 @@ class CudaGraphRunner:
|
||||
)
|
||||
if model_runner.sliding_window_size is None:
|
||||
self.flashinfer_workspace_buffer = (
|
||||
self.model_runner.flashinfer_workspace_buffers[0]
|
||||
self.model_runner.flashinfer_workspace_buffer
|
||||
)
|
||||
else:
|
||||
self.flashinfer_workspace_buffers = [
|
||||
self.model_runner.flashinfer_workspace_buffers[0],
|
||||
self.model_runner.flashinfer_workspace_buffers[2],
|
||||
]
|
||||
self.flashinfer_workspace_buffer = (
|
||||
self.model_runner.flashinfer_workspace_buffer
|
||||
)
|
||||
|
||||
self.flashinfer_kv_indptr = [
|
||||
self.flashinfer_kv_indptr,
|
||||
self.flashinfer_kv_indptr.clone(),
|
||||
@@ -200,7 +200,7 @@ class CudaGraphRunner:
|
||||
for i in range(2):
|
||||
flashinfer_decode_wrapper.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[i],
|
||||
self.flashinfer_workspace_buffer,
|
||||
"NHD",
|
||||
use_cuda_graph=True,
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
|
||||
@@ -318,28 +318,26 @@ class ModelRunner:
|
||||
use_tensor_cores = False
|
||||
|
||||
if self.sliding_window_size is None:
|
||||
self.flashinfer_workspace_buffers = torch.empty(
|
||||
2,
|
||||
self.flashinfer_workspace_buffer = torch.empty(
|
||||
global_config.flashinfer_workspace_size,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_ragged = (
|
||||
BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[0], "NHD"
|
||||
self.flashinfer_workspace_buffer, "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[1], "NHD"
|
||||
self.flashinfer_workspace_buffer, "NHD"
|
||||
)
|
||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[0],
|
||||
self.flashinfer_workspace_buffer,
|
||||
"NHD",
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
else:
|
||||
self.flashinfer_workspace_buffers = torch.empty(
|
||||
4,
|
||||
global_config.flashinfer_workspace_size,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
@@ -350,17 +348,17 @@ class ModelRunner:
|
||||
for i in range(2):
|
||||
self.flashinfer_prefill_wrapper_ragged.append(
|
||||
BatchPrefillWithRaggedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
|
||||
self.flashinfer_workspace_buffer, "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_paged.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
|
||||
self.flashinfer_workspace_buffer, "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_decode_wrapper.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.flashinfer_workspace_buffers[2 * i + 0],
|
||||
self.flashinfer_workspace_buffer,
|
||||
"NHD",
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
|
||||
@@ -381,7 +381,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if not server_args.disable_flashinfer:
|
||||
assert_pkg_version(
|
||||
"flashinfer",
|
||||
"0.1.4",
|
||||
"0.1.5",
|
||||
"Please uninstall the old version and "
|
||||
"reinstall the latest version by following the instructions "
|
||||
"at https://docs.flashinfer.ai/installation.html.",
|
||||
|
||||
Reference in New Issue
Block a user