Fix sliding window attention and gemma-2 unit tests in CI (#1746)
This commit is contained in:
@@ -342,23 +342,25 @@ class FlashInferIndicesUpdaterDecode:
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
# Sliding window attention
|
||||
paged_kernel_lens = torch.minimum( # TODO: replace this with clamp
|
||||
paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp
|
||||
seq_lens,
|
||||
torch.tensor(self.sliding_window_size + 1),
|
||||
)
|
||||
paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
|
||||
kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
|
||||
else:
|
||||
# Full attention
|
||||
paged_kernel_lens = seq_lens
|
||||
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
paged_kernel_lens_tmp = seq_lens
|
||||
paged_kernel_lens_sum_tmp = seq_lens_sum
|
||||
kv_start_idx_tmp = None
|
||||
|
||||
self.call_begin_forward(
|
||||
decode_wrappers[wrapper_id],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens_sum,
|
||||
paged_kernel_lens_tmp,
|
||||
paged_kernel_lens_sum_tmp,
|
||||
self.kv_indptr[wrapper_id],
|
||||
kv_start_idx,
|
||||
kv_start_idx_tmp,
|
||||
)
|
||||
|
||||
def update_cross_attention(self):
|
||||
@@ -369,14 +371,16 @@ class FlashInferIndicesUpdaterDecode:
|
||||
wrapper,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens_sum,
|
||||
paged_kernel_lens_sum,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda")
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
|
||||
@@ -102,8 +102,10 @@ class HFRunner:
|
||||
return False
|
||||
|
||||
def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
|
||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
||||
# Apply model-specific patches
|
||||
monkey_patch_gemma2_sdpa()
|
||||
|
||||
# Load the model and tokenizer
|
||||
if self.model_type == "generation":
|
||||
self.base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
@@ -128,7 +130,9 @@ class HFRunner:
|
||||
).cuda()
|
||||
else:
|
||||
raise Exception(f"Unrecognized model type {self.model_type}")
|
||||
self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
|
||||
|
||||
# Run forward
|
||||
while True:
|
||||
prompts, max_new_tokens, lora_paths = in_queue.get()
|
||||
if lora_paths is not None:
|
||||
@@ -370,3 +374,18 @@ class SRTRunner:
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.runtime.shutdown()
|
||||
del self.runtime
|
||||
|
||||
|
||||
def monkey_patch_gemma2_sdpa():
|
||||
"""
|
||||
Use sdpa by default to fix the OOM issue.
|
||||
Revert this commit:
|
||||
https://github.com/huggingface/transformers/commit/975b988bfe6e7ebb47390cd9a1556c6888804883#diff-5f76eac6f18f4b491521314c318a9692318feb4d19228e9576cce7bde4240834R660
|
||||
"""
|
||||
from transformers.models.gemma2.modeling_gemma2 import Gemma2PreTrainedModel
|
||||
|
||||
def _check_and_enable_sdpa(config, hard_check_only: bool = False):
|
||||
config._attn_implementation = "sdpa"
|
||||
return config
|
||||
|
||||
setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
|
||||
|
||||
Reference in New Issue
Block a user