73 lines
2.9 KiB
Python
73 lines
2.9 KiB
Python
import os
|
|
from typing import Any, Dict, Union
|
|
|
|
import torch
|
|
import transformers
|
|
from modelscope.models.base import Model, TorchModel
|
|
from modelscope.models.builder import MODELS
|
|
from modelscope.pipelines.base import Pipeline
|
|
from modelscope.pipelines.builder import PIPELINES
|
|
from modelscope.pipelines.nlp.text_generation_pipeline import \
|
|
TextGenerationPipeline
|
|
from modelscope.utils.constant import Tasks
|
|
from modelscope.utils.logger import get_logger
|
|
from transformers import LlamaForCausalLM, LlamaTokenizer
|
|
|
|
|
|
@PIPELINES.register_module(
    Tasks.text_generation,
    module_name='chinese-alpaca-plus-13b-hf-text-generation-pipe')
class chinesealpacaplus13bhfTextGenerationPipeline(TextGenerationPipeline):
    """Pass-through text-generation pipeline for chinese-alpaca-plus-13b-hf.

    Pre- and post-processing are identity functions: the raw call input goes
    straight to the wrapped model, and every call-time keyword argument is
    routed to ``forward`` (and from there to the model).
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline around a model instance or a model directory.

        Args:
            model: An already-constructed model, or a string that is passed
                to ``chinesealpacaplus13bhfTextGeneration`` as its model
                directory.
            *args: Accepted for signature compatibility; not forwarded.
            **kwargs: Forwarded unchanged to ``TextGenerationPipeline``.
        """
        if isinstance(model, str):
            # A path/id was given: load the wrapper model defined in this file.
            model = chinesealpacaplus13bhfTextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` untouched; tokenization happens inside the model."""
        return inputs

    def _sanitize_parameters(self, **call_kwargs):
        """Split call-time kwargs into per-stage parameter dicts.

        Returns a 3-tuple whose middle slot carries all of ``call_kwargs``
        while the first and last slots are fresh empty dicts. Under the base
        Pipeline's (preprocess, forward, postprocess) convention, this sends
        every kwarg to ``forward``.
        """
        preprocess_kwargs, postprocess_kwargs = {}, {}
        return preprocess_kwargs, call_kwargs, postprocess_kwargs

    # define the forward pass
    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Call the wrapped model on ``inputs`` with the routed kwargs."""
        wrapped_model = self.model
        return wrapped_model(inputs, **forward_params)

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output exactly as produced."""
        return input
|
|
|
|
|
|
@MODELS.register_module(Tasks.text_generation, module_name='chinese-alpaca-plus-13b-hf')
class chinesealpacaplus13bhfTextGeneration(TorchModel):
    """ModelScope wrapper around a Llama causal LM for plain-text generation.

    Loads a ``LlamaTokenizer`` and an fp16 ``LlamaForCausalLM`` from
    ``model_dir``. ``forward`` returns ``{'text': <generated text>}``;
    ``infer`` returns the generated text directly.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and model weights from ``model_dir``.

        Args:
            model_dir: Location of the checkpoint, passed to both
                ``from_pretrained`` calls.
            *args, **kwargs: Forwarded to ``TorchModel.__init__``.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = LlamaTokenizer.from_pretrained(model_dir,
                                                        use_fast=False)
        # device_map="auto" lets transformers distribute the weights over the
        # available devices; float16 halves memory versus float32.
        self.model = LlamaForCausalLM.from_pretrained(
            model_dir,
            low_cpu_mem_usage=True,
            device_map="auto",
            torch_dtype=torch.float16)
        # Inference only: disable dropout and other training-mode behavior.
        self.model = self.model.eval()

    def forward(self, input: str, *args, **kwargs) -> Dict[str, Any]:
        """Generate text for ``input`` and wrap it in a result dict.

        Args:
            input: Prompt text, passed to ``infer``.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Forwarded to ``infer`` (and from there to ``generate``).

        Returns:
            ``{'text': <decoded continuation>}``.
        """
        return {'text': self.infer(input, **kwargs)}

    def quantize(self, bits: int):
        """Replace the wrapped model with a ``bits``-bit quantized version.

        This only works when the underlying model object exposes a
        ``quantize(bits)`` method that returns the quantized model.

        Raises:
            NotImplementedError: If the wrapped model has no ``quantize``
                method.

        Returns:
            ``self``, to allow chaining.
        """
        quantize_fn = getattr(self.model, 'quantize', None)
        if quantize_fn is None:
            raise NotImplementedError(
                f'{type(self.model).__name__} does not provide a '
                f'quantize() method; cannot quantize to {bits} bits.')
        self.model = quantize_fn(bits)
        return self

    def infer(self, input, max_new_tokens=1024, **kwargs):
        """Generate a continuation of ``input`` and return it as text.

        Args:
            input: Prompt text fed to the tokenizer.
            max_new_tokens: Upper bound on newly generated tokens.
            **kwargs: Extra keyword arguments forwarded to
                ``self.model.generate``. A caller-supplied
                ``attention_mask`` takes precedence over the tokenizer's.

        Returns:
            The decoded continuation of the first returned sequence, with the
            prompt tokens removed, special tokens skipped, and surrounding
            whitespace stripped.
        """
        kwargs['max_new_tokens'] = max_new_tokens
        device = self.model.device
        encoded = self.tokenizer(input, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        # Pass the mask explicitly so generate() does not have to infer it
        # from pad/eos ids (and warn); setdefault keeps a caller's mask.
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            kwargs.setdefault('attention_mask', attention_mask.to(device))
        output_ids = self.model.generate(input_ids, **kwargs)
        # generate() echoes the prompt before the new tokens; drop it.
        prompt_len = input_ids.shape[-1]
        new_token_ids = output_ids[0][prompt_len:]
        return self.tokenizer.decode(new_token_ids,
                                     skip_special_tokens=True).strip()
|