Use a single workspace for flashinfer (#1077)

This commit is contained in:
Lianmin Zheng
2024-08-14 19:25:37 -07:00
committed by GitHub
parent 6767e2229f
commit 326df4bab2
5 changed files with 16 additions and 18 deletions

View File

@@ -318,28 +318,26 @@ class ModelRunner:
use_tensor_cores = False
if self.sliding_window_size is None:
self.flashinfer_workspace_buffers = torch.empty(
2,
self.flashinfer_workspace_buffer = torch.empty(
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
)
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffers[0], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[1], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[0],
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)
else:
self.flashinfer_workspace_buffers = torch.empty(
4,
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
device="cuda",
@@ -350,17 +348,17 @@ class ModelRunner:
for i in range(2):
self.flashinfer_prefill_wrapper_ragged.append(
BatchPrefillWithRaggedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
self.flashinfer_workspace_buffer, "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[2 * i + 0],
self.flashinfer_workspace_buffer,
"NHD",
use_tensor_cores=use_tensor_cores,
)