[Fix] Compatibility of window attention and cuda graph (#1090)

This commit is contained in:
Ying Sheng
2024-08-14 10:37:01 -07:00
committed by GitHub
parent a34dd86a7d
commit 96a2093ef0
7 changed files with 70 additions and 39 deletions

View File

@@ -187,6 +187,11 @@ class ModelRunner:
scheduler_config=None,
cache_config=None,
)
+ self.sliding_window_size = (
+     self.model.get_window_size()
+     if hasattr(self.model, "get_window_size")
+     else None
+ )
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures
)
@@ -295,12 +300,6 @@ class ModelRunner:
return c
def init_flashinfer(self):
- self.sliding_window_size = (
-     self.model.get_window_size()
-     if hasattr(self.model, "get_window_size")
-     else None
- )
if self.server_args.disable_flashinfer:
assert (
self.sliding_window_size is None
@@ -339,7 +338,7 @@ class ModelRunner:
use_tensor_cores=use_tensor_cores,
)
else:
- workspace_buffers = torch.empty(
+ self.flashinfer_workspace_buffers = torch.empty(
4,
global_config.flashinfer_workspace_size,
dtype=torch.uint8,
@@ -351,17 +350,17 @@ class ModelRunner:
for i in range(2):
self.flashinfer_prefill_wrapper_ragged.append(
BatchPrefillWithRaggedKVCacheWrapper(
- workspace_buffers[2 * i + 0], "NHD"
+ self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
)
)
self.flashinfer_prefill_wrapper_paged.append(
BatchPrefillWithPagedKVCacheWrapper(
- workspace_buffers[2 * i + 1], "NHD"
+ self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
)
)
self.flashinfer_decode_wrapper.append(
BatchDecodeWithPagedKVCacheWrapper(
- workspace_buffers[2 * i + 0],
+ self.flashinfer_workspace_buffers[2 * i + 0],
"NHD",
use_tensor_cores=use_tensor_cores,
)
@@ -404,7 +403,6 @@ class ModelRunner:
self,
batch,
ForwardMode.DECODE,
- sliding_window_size=self.sliding_window_size,
)
return self.model.forward(
@@ -417,7 +415,6 @@ class ModelRunner:
self,
batch,
forward_mode=ForwardMode.EXTEND,
- sliding_window_size=self.sliding_window_size,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
@@ -429,7 +426,6 @@ class ModelRunner:
self,
batch,
forward_mode=ForwardMode.EXTEND,
- sliding_window_size=self.sliding_window_size,
)
return self.model.forward(
batch.input_ids,