[Fix] Compatibility of window attention and cuda graph (#1090)
This commit is contained in:
@@ -187,6 +187,11 @@ class ModelRunner:
|
||||
scheduler_config=None,
|
||||
cache_config=None,
|
||||
)
|
||||
+ self.sliding_window_size = (
|
||||
+ self.model.get_window_size()
|
||||
+ if hasattr(self.model, "get_window_size")
|
||||
+ else None
|
||||
+ )
|
||||
self.is_generation = is_generation_model(
|
||||
self.model_config.hf_config.architectures
|
||||
)
|
||||
@@ -295,12 +300,6 @@ class ModelRunner:
|
||||
return c
|
||||
|
||||
def init_flashinfer(self):
|
||||
- self.sliding_window_size = (
|
||||
- self.model.get_window_size()
|
||||
- if hasattr(self.model, "get_window_size")
|
||||
- else None
|
||||
- )
|
||||
|
||||
if self.server_args.disable_flashinfer:
|
||||
assert (
|
||||
self.sliding_window_size is None
|
||||
@@ -339,7 +338,7 @@ class ModelRunner:
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
else:
|
||||
- workspace_buffers = torch.empty(
|
||||
+ self.flashinfer_workspace_buffers = torch.empty(
|
||||
4,
|
||||
global_config.flashinfer_workspace_size,
|
||||
dtype=torch.uint8,
|
||||
@@ -351,17 +350,17 @@ class ModelRunner:
|
||||
for i in range(2):
|
||||
self.flashinfer_prefill_wrapper_ragged.append(
|
||||
BatchPrefillWithRaggedKVCacheWrapper(
|
||||
- workspace_buffers[2 * i + 0], "NHD"
|
||||
+ self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_prefill_wrapper_paged.append(
|
||||
BatchPrefillWithPagedKVCacheWrapper(
|
||||
- workspace_buffers[2 * i + 1], "NHD"
|
||||
+ self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
|
||||
)
|
||||
)
|
||||
self.flashinfer_decode_wrapper.append(
|
||||
BatchDecodeWithPagedKVCacheWrapper(
|
||||
- workspace_buffers[2 * i + 0],
|
||||
+ self.flashinfer_workspace_buffers[2 * i + 0],
|
||||
"NHD",
|
||||
use_tensor_cores=use_tensor_cores,
|
||||
)
|
||||
@@ -404,7 +403,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
ForwardMode.DECODE,
|
||||
- sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
|
||||
return self.model.forward(
|
||||
@@ -417,7 +415,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
- sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids, input_metadata.positions, input_metadata
|
||||
@@ -429,7 +426,6 @@ class ModelRunner:
|
||||
self,
|
||||
batch,
|
||||
forward_mode=ForwardMode.EXTEND,
|
||||
- sliding_window_size=self.sliding_window_size,
|
||||
)
|
||||
return self.model.forward(
|
||||
batch.input_ids,
|
||||
|
||||
Reference in New Issue
Block a user