初始化项目,由ModelHub XC社区提供模型
Model: PAIXAI/Astrid-1B-CPU Source: Original Platform
This commit is contained in:
42
h2oai_pipeline.py
Normal file
42
h2oai_pipeline.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
# Prompt template applied to every input: the caller's text is substituted for
# ``{instruction}`` and the template ends with the ``<|answer|>`` marker.
STYLE = "<|prompt|>{instruction}<|endoftext|><|answer|>"
|
||||
|
||||
|
||||
class H2OTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that wraps inputs in the ``STYLE`` template.

    ``preprocess`` formats the incoming text into the template before handing
    it to the parent pipeline, and ``postprocess`` trims each decoded record
    down to the text that follows the ``<|answer|>`` marker and precedes any
    subsequent ``<|prompt|>`` marker.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and remember the template."""
        super().__init__(*args, **kwargs)
        self.prompt = STYLE

    def preprocess(
        self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs
    ):
        """Format ``prompt_text`` into the template and delegate to the parent.

        Args:
            prompt_text: Text substituted for ``{instruction}`` in the template.
            prefix: Passed through unchanged to the parent ``preprocess``.
            handle_long_generation: Passed through unchanged to the parent.
            **generate_kwargs: Passed through unchanged to the parent.

        Returns:
            Whatever the parent ``preprocess`` returns for the formatted text.
        """
        prompt_text = self.prompt.format(instruction=prompt_text)
        return super().preprocess(
            prompt_text,
            prefix=prefix,
            handle_long_generation=handle_long_generation,
            **generate_kwargs,
        )

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
    ):
        """Run the parent ``postprocess`` and strip the template markers.

        For every record that carries a string under ``"generated_text"``, the
        value is replaced by the segment after the first ``<|answer|>`` marker
        (or the whole text when that marker is absent), cut at the first
        ``<|prompt|>`` marker, with surrounding whitespace removed.

        Args:
            model_outputs: Passed through unchanged to the parent.
            return_type: Passed through unchanged to the parent.
            clean_up_tokenization_spaces: Passed through unchanged to the parent.

        Returns:
            The records produced by the parent, updated in place.
        """
        records = super().postprocess(
            model_outputs,
            return_type=return_type,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        )
        for rec in records:
            text = rec.get("generated_text")
            if not isinstance(text, str):
                # No textual output on this record (presumably a non-text
                # return type) -- leave it untouched instead of raising.
                continue
            segments = text.split("<|answer|>")
            # When the marker is missing (e.g. only newly generated text was
            # returned, without the prompt), keep the whole text rather than
            # failing with IndexError on segments[1].
            answer = segments[1] if len(segments) > 1 else segments[0]
            rec["generated_text"] = answer.strip().split("<|prompt|>")[0].strip()
        return records
|
||||
Reference in New Issue
Block a user