初始化项目,由ModelHub XC社区提供模型
Model: transformers-community/custom_generate_bad_requirements Source: Original Platform
This commit is contained in:
32
custom_generate/generate.py
Normal file
32
custom_generate/generate.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
|
||||
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
|
||||
|
||||
print("✨ using a custom generation method ✨")
|
||||
|
||||
generation_config = generation_config or model.generation_config # default to the model generation config
|
||||
cur_length = input_ids.shape[1]
|
||||
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
||||
|
||||
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
|
||||
if left_padding is not None:
|
||||
if not isinstance(left_padding, int) or left_padding < 0:
|
||||
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
||||
|
||||
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
|
||||
if pad_token is None:
|
||||
raise ValueError("pad_token is not defined")
|
||||
batch_size = input_ids.shape[0]
|
||||
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
||||
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
||||
cur_length = input_ids.shape[1]
|
||||
|
||||
# Simple greedy decoding loop
|
||||
while cur_length < max_length:
|
||||
logits = model(input_ids).logits
|
||||
next_token_logits = logits[:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
|
||||
cur_length += 1
|
||||
|
||||
return input_ids
|
||||
3
custom_generate/requirements.txt
Normal file
3
custom_generate/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
foo
|
||||
bar==0.0.0
|
||||
torch>=99.0
|
||||
Reference in New Issue
Block a user