5.1 KiB
用于生成的工具
此页面列出了所有由 [~generation.GenerationMixin.generate]。
生成输出
[~generation.GenerationMixin.generate] 的输出是 [~utils.ModelOutput] 的一个子类的实例。这个输出是一种包含 [~generation.GenerationMixin.generate] 返回的所有信息数据结构,但也可以作为元组或字典使用。
这里是一个例子:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
generation_output 的对象是 [~generation.GenerateDecoderOnlyOutput] 的一个实例,从该类的文档中我们可以看到,这意味着它具有以下属性:
sequences: 生成的tokens序列scores(可选): 每个生成步骤的语言建模头的预测分数hidden_states(可选): 每个生成步骤模型的hidden statesattentions(可选): 每个生成步骤模型的注意力权重
在这里,由于我们传递了 output_scores=True,我们具有 scores 属性。但我们没有 hidden_states 和 attentions,因为没有传递 output_hidden_states=True 或 output_attentions=True。
您可以像通常一样访问每个属性,如果该属性未被模型返回,则将获得 None。例如,在这里 generation_output.scores 是语言建模头的所有生成预测分数,而 generation_output.attentions 为 None。
当我们将 generation_output 对象用作元组时,它只保留非 None 值的属性。例如,在这里它有两个元素,loss 然后是 logits,所以
generation_output[:2]
将返回元组(generation_output.sequences, generation_output.scores)。
当我们将generation_output对象用作字典时,它只保留非None的属性。例如,它有两个键,分别是sequences和scores。
我们在此记录所有输出类型。
PyTorch
autodoc generation.GenerateDecoderOnlyOutput
autodoc generation.GenerateEncoderDecoderOutput
autodoc generation.GenerateBeamDecoderOnlyOutput
autodoc generation.GenerateBeamEncoderDecoderOutput
LogitsProcessor
[LogitsProcessor] 可以用于修改语言模型头的预测分数以进行生成
PyTorch
autodoc AlternatingCodebooksLogitsProcessor - call
autodoc ClassifierFreeGuidanceLogitsProcessor - call
autodoc EncoderNoRepeatNGramLogitsProcessor - call
autodoc EncoderRepetitionPenaltyLogitsProcessor - call
autodoc EpsilonLogitsWarper - call
autodoc EtaLogitsWarper - call
autodoc ExponentialDecayLengthPenalty - call
autodoc ForcedBOSTokenLogitsProcessor - call
autodoc ForcedEOSTokenLogitsProcessor - call
autodoc InfNanRemoveLogitsProcessor - call
autodoc LogitNormalization - call
autodoc LogitsProcessor - call
autodoc LogitsProcessorList - call
autodoc MinLengthLogitsProcessor - call
autodoc MinNewTokensLengthLogitsProcessor - call
autodoc NoBadWordsLogitsProcessor - call
autodoc NoRepeatNGramLogitsProcessor - call
autodoc PrefixConstrainedLogitsProcessor - call
autodoc RepetitionPenaltyLogitsProcessor - call
autodoc SequenceBiasLogitsProcessor - call
autodoc SuppressTokensAtBeginLogitsProcessor - call
autodoc SuppressTokensLogitsProcessor - call
autodoc TemperatureLogitsWarper - call
autodoc TopKLogitsWarper - call
autodoc TopPLogitsWarper - call
autodoc TypicalLogitsWarper - call
autodoc UnbatchedClassifierFreeGuidanceLogitsProcessor - call
autodoc WhisperTimeStampLogitsProcessor - call
StoppingCriteria
可以使用[StoppingCriteria]来更改停止生成的时间(除了EOS token以外的方法)。请注意,这仅适用于我们的PyTorch实现。
autodoc StoppingCriteria - call
autodoc StoppingCriteriaList - call
autodoc MaxLengthCriteria - call
autodoc MaxTimeCriteria - call
Streamers
autodoc TextStreamer
autodoc TextIteratorStreamer