replace get_max_length

Author: x54-729
Date: 2025-03-13 15:04:55 +08:00
Parent: b7229bee87
Commit: 571400500e

@@ -1081,7 +1081,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1274,8 +1274,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
         if isinstance(past_key_values, Cache):
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
+                torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
+                if past_key_values.get_max_cache_shape() is not None
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
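
The rename tracks upstream Hugging Face transformers, where Cache.get_max_length() was deprecated in favor of Cache.get_max_cache_shape(); both return the cache's maximum capacity, or None for dynamically sized caches. As written, the diff assumes a transformers release that already ships get_max_cache_shape(), i.e. it raises the minimum supported transformers version. For code that must run against both old and new releases, a minimal compatibility sketch is shown below; the helper name get_cache_max_length is hypothetical and not part of this commit:

    from typing import Optional

    from transformers.cache_utils import Cache, DynamicCache


    def get_cache_max_length(past_key_values: Cache) -> Optional[int]:
        """Return the cache's maximum capacity, or None if it grows dynamically.

        Newer transformers releases expose Cache.get_max_cache_shape(); older
        ones only have the since-deprecated Cache.get_max_length(). Probing
        with hasattr() lets the same modeling code run against either version.
        """
        if hasattr(past_key_values, "get_max_cache_shape"):
            return past_key_values.get_max_cache_shape()
        return past_key_values.get_max_length()


    # A DynamicCache grows with the sequence, so it reports no fixed capacity:
    print(get_cache_max_length(DynamicCache()))  # None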