32 lines
1.5 KiB
Python
32 lines
1.5 KiB
Python
|
|
import torch
|
||
|
|
|
||
|
|
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
    """Greedy-decode a continuation of `input_ids`, optionally left-padding the prompt.

    Args:
        model: Callable as `model(input_ids)` returning an object with a `.logits`
            attribute. Must expose `generation_config`; `config.pad_token_id` is
            only consulted when no other pad token source is available.
        input_ids: Tensor of token ids; dimension 0 is used as the batch size and
            dimension 1 as the current sequence length.
        generation_config: Object providing `max_new_tokens`, `max_length` and
            `pad_token_id`. Falls back to `model.generation_config` when falsy.
        left_padding: Optional non-negative integer; that many pad tokens are
            prepended to every row of `input_ids` before decoding.
        **kwargs: Only `pad_token` is read (and only when `left_padding` is given);
            every other keyword is ignored.

    Returns:
        The (possibly padded) `input_ids` with the greedily chosen tokens appended.

    Raises:
        ValueError: If `left_padding` is not a non-negative integer, if no pad
            token can be resolved, or if neither `max_new_tokens` nor `max_length`
            is defined.
    """
    print("✨ using a custom generation method ✨")

    # Default to the model generation config.
    generation_config = generation_config or model.generation_config

    # Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
    if left_padding is not None:
        if not isinstance(left_padding, int) or left_padding < 0:
            raise ValueError(f"left_padding must be a non-negative integer, but is {left_padding}")

        # Resolve the pad token with explicit `None` checks: 0 is a legitimate
        # token id and must not be treated as "missing" the way an `or` chain would.
        pad_token = kwargs.pop("pad_token", None)
        if pad_token is None:
            pad_token = generation_config.pad_token_id
        if pad_token is None:
            pad_token = model.config.pad_token_id
        if pad_token is None:
            raise ValueError("pad_token is not defined")

        batch_size = input_ids.shape[0]
        # Allocate directly on the target device instead of building on CPU and moving.
        pad_tensor = torch.full((batch_size, left_padding), pad_token, device=input_ids.device)
        input_ids = torch.cat((pad_tensor, input_ids), dim=1)

    # Measure the prompt *after* padding so the padding does not consume the
    # budget of newly generated tokens.
    cur_length = input_ids.shape[1]

    # `max_new_tokens` takes precedence over `max_length` when both are set.
    max_new_tokens = getattr(generation_config, "max_new_tokens", None)
    configured_max_length = getattr(generation_config, "max_length", None)
    if max_new_tokens is not None:
        max_length = cur_length + max_new_tokens
    elif configured_max_length is not None:
        max_length = configured_max_length
    else:
        raise ValueError("Either max_new_tokens or max_length must be defined in the generation config")

    # Simple greedy decoding loop. The output holds integer ids only, so gradients
    # are never needed; disabling them avoids retaining a graph per step.
    with torch.no_grad():
        while cur_length < max_length:
            logits = model(input_ids).logits
            next_token_logits = logits[:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1)
            input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
            cur_length += 1

    return input_ids
|