[Feat] Add window attention for gemma-2 (#1056)
@@ -295,7 +295,16 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
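The window size is discovered by duck typing: a model that implements get_window_size opts in, anything else falls back to None and keeps dense attention everywhere. When FlashInfer is disabled, the new assert refuses to start with a windowed model, since only the FlashInfer path gains sliding-window support in this change. A minimal sketch of the model-side counterpart, assuming a config object with a sliding_window field (the attribute name is an illustration, not something this diff shows):

    # Hypothetical model-side counterpart (not part of this commit):
    # expose the window span so ModelRunner can size its wrapper sets.
    class Gemma2ForCausalLM:
        def __init__(self, config):
            self.config = config

        def get_window_size(self) -> int:
            # The `sliding_window` attribute name is assumed here purely
            # for illustration; Gemma-2 uses a 4096-token window.
            return self.config.sliding_window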
@@ -309,20 +318,54 @@ class ModelRunner:
         else:
             use_tensor_cores = False

-        self.flashinfer_workspace_buffers = torch.empty(
-            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-        )
-        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD"
-        )
-        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[1], "NHD"
-        )
-        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
-            "NHD",
-            use_tensor_cores=use_tensor_cores,
-        )
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffers = torch.empty(
+                2,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffers[0], "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[0],
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            workspace_buffers = torch.empty(
+                4,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = []
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_ragged.append(
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0], "NHD"
+                    )
+                )
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 1], "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0],
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )

     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
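With a window size present, each wrapper attribute becomes a two-element list backed by four workspace buffers instead of two: one ragged-prefill/paged-prefill/decode set per attention variant, since Gemma-2 interleaves sliding-window and full-attention layers. Note that the decode wrappers reuse the ragged-prefill buffers (2 * i + 0), mirroring the non-windowed path where buffer 0 is shared the same way. The hunk does not show which list index serves which variant, so the per-layer selection below is a hedged sketch with an assumed convention:

    # Hypothetical per-layer wrapper selection; the index convention
    # (0 = sliding-window layers, 1 = full-attention layers) is an
    # assumption, not taken from this diff.
    def pick_decode_wrapper(model_runner, layer_uses_sliding_window: bool):
        wrappers = model_runner.flashinfer_decode_wrapper
        if not isinstance(wrappers, list):
            # No window attention configured: one wrapper serves all layers.
            return wrappers
        return wrappers[0 if layer_uses_sliding_window else 1]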
@@ -358,7 +401,10 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, ForwardMode.DECODE
+            self,
+            batch,
+            ForwardMode.DECODE,
+            sliding_window_size=self.sliding_window_size,
         )

         return self.model.forward(
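Each forward entry point now threads the window size into InputMetadata; the two extend hunks below repeat the same change. InputMetadata itself is not part of this diff, so the consumer-side sketch here is an assumption about how the new keyword might be stored:

    # Hypothetical consumer-side sketch: keep the window size on the
    # metadata object so attention layers can plan FlashInfer calls per
    # attention variant. Not the actual sglang.srt implementation.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class InputMetadataSketch:
        sliding_window_size: Optional[int] = None  # None => dense attention only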
@@ -368,7 +414,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -377,7 +426,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,
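For reference, the semantics the windowed wrappers implement: a sliding-window layer attends over at most the last window_size tokens, so its effective KV length is clipped while full-attention layers still see the whole sequence. A minimal, framework-free sketch of that bound, assuming the usual last-N-tokens convention:

    # Minimal sketch of the sliding-window bound, assuming the common
    # "attend to the last `window_size` tokens" convention. The exact
    # indexing lives in SGLang's attention/InputMetadata code, which
    # this diff does not include.
    def windowed_kv_len(seq_len: int, window_size: int) -> int:
        return min(seq_len, window_size)

    assert windowed_kv_len(10, 4096) == 10      # short sequence: window inactive
    assert windowed_kv_len(8192, 4096) == 4096  # long sequence: clipped to window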