Fix sliding window attention and gemma-2 unit tests in CI (#1746)

Author: Lianmin Zheng
Date: 2024-10-21 13:47:12 -07:00
Committed by: GitHub
Parent: e68b9e7667
Commit: 00611286a1
4 changed files with 35 additions and 14 deletions


@@ -102,8 +102,10 @@ class HFRunner:
        return False

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()

        # Load the model and tokenizer
        if self.model_type == "generation":
            self.base_model = AutoModelForCausalLM.from_pretrained(
                model_path,
@@ -128,7 +130,9 @@ class HFRunner:
            ).cuda()
        else:
            raise Exception(f"Unrecognized model type {self.model_type}")

        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch_dtype)

        # Run forward
        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
            if lora_paths is not None:
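
For context, start_model_process runs in a separate process and follows a simple multiprocessing queue protocol: the parent puts (prompts, max_new_tokens, lora_paths) tuples on in_queue and reads results back from out_queue, as in the "Run forward" loop above. Below is a minimal, self-contained sketch of that pattern; model_worker and the echoed outputs are illustrative stand-ins for the real HF forward/generate calls, not the runner's actual code.

    import multiprocessing as mp

    def model_worker(in_queue, out_queue):
        # Consume (prompts, max_new_tokens, lora_paths) tuples, mirroring the
        # forward loop in the diff, and reply on out_queue.
        while True:
            prompts, max_new_tokens, lora_paths = in_queue.get()
            # Stand-in for the real model forward/generate call.
            outputs = [f"{p} + {max_new_tokens} new tokens" for p in prompts]
            out_queue.put(outputs)

    if __name__ == "__main__":
        in_q, out_q = mp.Queue(), mp.Queue()
        mp.Process(target=model_worker, args=(in_q, out_q), daemon=True).start()
        in_q.put((["Hello"], 8, None))
        print(out_q.get())  # ['Hello + 8 new tokens']
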
@@ -370,3 +374,18 @@ class SRTRunner:
    def __exit__(self, exc_type, exc_value, traceback):
        self.runtime.shutdown()
        del self.runtime


def monkey_patch_gemma2_sdpa():
    """
    Use sdpa by default to fix the OOM issue.
    Revert this commit:
    https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
    """
    from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel

    def _check_and_enable_sdpa(config, hard_check_only: bool = False):
        config._attn_implementation = "sdpa"
        return config

    setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
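
The patch must run before the model is instantiated, since _check_and_enable_sdpa is consulted while from_pretrained resolves the attention backend. A minimal usage sketch, assuming a Gemma-2 checkpoint (the model name is illustrative, not taken from this commit):

    from transformers import AutoModelForCausalLM

    monkey_patch_gemma2_sdpa()  # force the "sdpa" attention implementation
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",  # illustrative Gemma-2 checkpoint
        torch_dtype="bfloat16",
    )
    print(model.config._attn_implementation)  # expected: "sdpa"

Overriding the method on Gemma2PreTrainedModel, rather than patching transformers globally, keeps the change scoped to Gemma-2 and leaves other model classes untouched.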