Replace `get_max_length()` calls with `get_max_cache_shape()`
This commit is contained in:
@@ -1081,7 +1081,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
|
|||||||
min_dtype = torch.finfo(dtype).min
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
if using_static_cache:
|
||||||
target_length = past_key_values.get_max_length()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1274,8 +1274,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
|
|||||||
if isinstance(past_key_values, Cache):
|
if isinstance(past_key_values, Cache):
|
||||||
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||||
max_cache_length = (
|
max_cache_length = (
|
||||||
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
|
torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
|
||||||
if past_key_values.get_max_length() is not None
|
if past_key_values.get_max_cache_shape() is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
|
||||||
|
|||||||
Reference in New Issue
Block a user