[Feat] Add window attention for gemma-2 (#1056)
This commit is contained in:
@@ -17,9 +17,12 @@ limitations under the License.
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ServerArgs:
|
||||
@@ -446,6 +449,15 @@ class ServerArgs:
|
||||
assert not (
|
||||
self.dp_size > 1 and self.node_rank is not None
|
||||
), "multi-node data parallel is not supported"
|
||||
if "gemma-2" in self.model_path.lower():
|
||||
logger.info(
|
||||
f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
|
||||
)
|
||||
self.disable_radix_cache = True
|
||||
self.disable_regex_jump_forward = True
|
||||
self.disable_flashinfer = False
|
||||
self.disable_cuda_graph = True
|
||||
self.chunked_prefill_size = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
||||
Reference in New Issue
Block a user