From 571400500eee14d1c36e049cab59d39ece6f3d81 Mon Sep 17 00:00:00 2001 From: x54-729 Date: Thu, 13 Mar 2025 15:04:55 +0800 Subject: [PATCH] replace get_max_length --- modeling_internlm2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modeling_internlm2.py b/modeling_internlm2.py index 3c95101..d88cf30 100644 --- a/modeling_internlm2.py +++ b/modeling_internlm2.py @@ -1081,7 +1081,7 @@ class InternLM2Model(InternLM2PreTrainedModel): min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: - target_length = past_key_values.get_max_length() + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -1274,8 +1274,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel): if isinstance(past_key_values, Cache): past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None + torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device) + if past_key_values.get_max_cache_shape() is not None else None ) cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)