[Feat] Add window attention for gemma-2 (#1056)

This commit is contained in:
Ying Sheng
2024-08-13 17:01:26 -07:00
committed by GitHub
parent ad3e4f1619
commit 0909bb0d2f
11 changed files with 320 additions and 127 deletions

View File

@@ -17,9 +17,12 @@ limitations under the License.
import argparse
import dataclasses
import logging
import random
from typing import List, Optional, Union
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ServerArgs:
@@ -446,6 +449,15 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
# Gemma-2 override: when the model path contains "gemma-2" (case-insensitive),
# unconditionally overwrite the five settings below, whatever they held before.
# NOTE(review): this is a plain substring test on the path, so any path that
# happens to contain "gemma-2" triggers it — confirm that is intended.
if "gemma-2" in self.model_path.lower():
# NOTE(review): the message names only three of the five overrides applied
# below (the CUDA-graph and chunked-prefill changes are not mentioned), and
# the f-string contains no placeholders.
logger.info(
f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
)
# The three overrides announced by the log message: radix cache off,
# regex jump-forward off, flashinfer on (disable flag cleared).
self.disable_radix_cache = True
self.disable_regex_jump_forward = True
self.disable_flashinfer = False
# Also overridden here, though not mentioned in the log message.
self.disable_cuda_graph = True
# presumably None turns chunked prefill off — TODO confirm against the code
# that consumes chunked_prefill_size.
self.chunked_prefill_size = None
@dataclasses.dataclass