42 lines
1.3 KiB
Python
42 lines
1.3 KiB
Python
from transformers import TextGenerationPipeline
|
|
from transformers.pipelines.text_generation import ReturnType
|
|
|
|
# Prompt template: the user's instruction is substituted for {instruction};
# the model is expected to write its reply after the "<|answer|>" marker.
STYLE = "<|prompt|>{instruction}<|endoftext|><|answer|>"
|
|
|
|
|
|
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that applies the h2oGPT prompt style.

    Wraps every incoming instruction in the ``STYLE`` template before
    tokenization, and strips the template markers back out of the decoded
    generation so callers receive only the answer text.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the underlying pipeline and store the prompt template.

        All positional/keyword arguments are forwarded unchanged to
        ``TextGenerationPipeline.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Template applied to the raw instruction in preprocess().
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Wrap the raw instruction in the prompt template, then delegate.

        Args:
            prompt_text: The user's raw instruction string.
            prefix: Forwarded to the base implementation.
            handle_long_generation: Forwarded to the base implementation.
            **generate_kwargs: Forwarded to the base implementation.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Decode outputs and strip the prompt-template markers.

        Fix: the original code did ``split("<|answer|>")[1]`` unconditionally,
        which raises IndexError whenever the marker is absent from the decoded
        text — e.g. with ``return_type=ReturnType.NEW_TEXT``, where only newly
        generated tokens (without the prompt) are returned. We now fall back to
        the full text when the marker is missing; behavior is unchanged when
        it is present (the FULL_TEXT default path).
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            text = rec["generated_text"]
            # Keep only the segment after the answer marker, if present.
            parts = text.split("<|answer|>")
            if len(parts) > 1:
                text = parts[1].strip()
            # Truncate at any follow-on prompt the model may have generated.
            rec["generated_text"] = text.split("<|prompt|>")[0].strip()
        return records