[Feat] Add window attention for gemma-2 (#1056)

Author: Ying Sheng
Date: 2024-08-13 17:01:26 -07:00
Committed by: GitHub
Parent: ad3e4f1619
Commit: 0909bb0d2f
11 changed files with 320 additions and 127 deletions


@@ -295,7 +295,16 @@ class ModelRunner:
         return c
 
     def init_flashinfer(self):
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
+
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
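
The new preamble asks the model whether it uses window attention: only models that implement get_window_size (Gemma-2 in this PR) take the sliding-window path, and the feature is refused outright when flashinfer is disabled. As a rough sketch of the model-side hook, assuming the model's HuggingFace config carries a sliding_window field: the method name and the hasattr guard come from the diff, while the class skeleton and body here are illustrative assumptions, not the committed code.

    import torch.nn as nn

    class Gemma2ForCausalLM(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config  # assumed HF Gemma2Config with `sliding_window`

        def get_window_size(self):
            # Assumed: report the sliding-window width from the config so that
            # ModelRunner.init_flashinfer can detect window attention.
            return self.config.sliding_window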
@@ -309,20 +318,54 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        self.flashinfer_workspace_buffers = torch.empty(
-            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-        )
-        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD"
-        )
-        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[1], "NHD"
-        )
-        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
-            "NHD",
-            use_tensor_cores=use_tensor_cores,
-        )
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffers = torch.empty(
+                2,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffers[0], "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffers[0],
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            workspace_buffers = torch.empty(
+                4,
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = []
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_ragged.append(
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0], "NHD"
+                    )
+                )
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 1], "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        workspace_buffers[2 * i + 0],
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
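
When a window size is present, the runner allocates four workspace buffers instead of two and builds two parallel sets of ragged-prefill, paged-prefill, and decode wrappers. This matches Gemma-2's architecture, which interleaves sliding-window layers with full-attention layers, so a single forward pass needs both kinds of attention metadata live at once. A rough sketch of how a layer might select its wrapper set follows; the index convention (set 0 for sliding-window layers, set 1 for full attention) is an assumption for illustration, not something this diff establishes.

    def select_decode_wrapper(runner, layer_is_sliding: bool):
        # runner.flashinfer_decode_wrapper is a single wrapper when no window
        # attention is configured, and a two-element list otherwise.
        wrappers = runner.flashinfer_decode_wrapper
        if runner.sliding_window_size is None:
            return wrappers
        # Assumed convention: index 0 -> sliding-window layers, index 1 -> full.
        return wrappers[0] if layer_is_sliding else wrappers[1]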
@@ -358,7 +401,10 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, ForwardMode.DECODE
+            self,
+            batch,
+            ForwardMode.DECODE,
+            sliding_window_size=self.sliding_window_size,
         )
 
         return self.model.forward(
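
All three forward entry points now thread self.sliding_window_size into InputMetadata.from_schedule_batch, which implies the classmethod grew a matching keyword. A sketch of the assumed receiving signature: the parameter names mirror the call sites above, while the default value and the behavior in the comment are assumptions about the rest of the PR.

    class InputMetadata:
        @classmethod
        def from_schedule_batch(
            cls,
            model_runner,
            batch,
            forward_mode,
            sliding_window_size=None,
        ):
            # Assumed: when sliding_window_size is set, the metadata builder
            # clamps each request's visible KV length to the window before
            # calling the flashinfer wrappers' begin_forward.
            ...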
@@ -368,7 +414,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -377,7 +426,10 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self, batch, forward_mode=ForwardMode.EXTEND
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
+            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids,