Files
Vicuna-7B/ms_wrapper.py

78 lines
3.0 KiB
Python
Raw Normal View History

import os
from typing import Any, Dict, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
from modelscope.models.base import Model, TorchModel
from modelscope.models.builder import MODELS
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.nlp.text_generation_pipeline import \
TextGenerationPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
# Prompt template applied to every request before tokenization. The
# `{prompt}` placeholder is filled in with str.format(prompt=...); the text
# the model generates after the trailing "### Assistant:\n" marker is what
# gets returned to the caller.
Vicuna_PROMPT_FORMAT = "### Human:\n{prompt} \n ### Assistant:\n"
@PIPELINES.register_module(
    Tasks.text_generation, module_name='Vicuna7b-text-generation-pipe')
class Vicuna7bTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline that hands raw inputs to a Vicuna-7B model.

    Pre- and post-processing are pass-throughs: whatever the caller supplies
    is given to the model unchanged, and whatever the model returns is given
    back to the caller unchanged.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline around a model instance or a model directory.

        Args:
            model: An already constructed model, or a string that is used
                to construct a ``Vicuna7bTextGeneration``.
            *args: Accepted for signature compatibility; not forwarded to
                the parent constructor.
            **kwargs: Forwarded to the parent pipeline constructor.
        """
        if isinstance(model, str):
            # A string is treated as the location to load the model from.
            model = Vicuna7bTextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` untouched; no preprocessing is performed."""
        return inputs

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route every call-time keyword argument to the forward step.

        Returns:
            A 3-tuple whose first and last members are empty dicts and whose
            middle member is ``pipeline_parameters``.
            NOTE(review): the tuple order is presumably
            (preprocess, forward, postprocess) as expected by the base
            pipeline -- confirm against the modelscope Pipeline contract.
        """
        no_preprocess_params = {}
        no_postprocess_params = {}
        return no_preprocess_params, pipeline_parameters, no_postprocess_params

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Invoke the wrapped model with ``inputs`` and the forward kwargs."""
        model_output = self.model(inputs, **forward_params)
        return model_output

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output untouched; no postprocessing is done."""
        return input
@MODELS.register_module(Tasks.text_generation, module_name='Vicuna7b')
class Vicuna7bTextGeneration(TorchModel):
    """ModelScope model wrapper around a causal LM checkpoint for Vicuna-7B.

    Loads a ``LlamaTokenizer`` and an ``AutoModelForCausalLM`` from
    ``model_dir`` and exposes single-prompt text generation through
    ``forward`` / ``infer``.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and the half-precision model from ``model_dir``.

        Args:
            model_dir: Directory (or identifier) passed to both
                ``from_pretrained`` calls.
            *args: Forwarded to ``TorchModel.__init__``.
            **kwargs: Forwarded to ``TorchModel.__init__``.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = LlamaTokenizer.from_pretrained(model_dir,
                                                        use_fast=False)
        # Weights are loaded in float16 and placed by `device_map="auto"`.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            low_cpu_mem_usage=True,
            device_map="auto",
            torch_dtype=torch.float16)
        # Inference only: switch the module to eval mode.
        self.model = self.model.eval()

    def forward(self, input: str, *args, **kwargs) -> Dict[str, Any]:
        """Generate a reply for ``input`` and wrap it in a result dict.

        Args:
            input: The user prompt. It is substituted into
                ``Vicuna_PROMPT_FORMAT`` as text by ``infer``.
            *args: Ignored.
            **kwargs: Forwarded to ``infer`` (and from there to
                ``generate``).

        Returns:
            ``{'text': <generated reply>}``.
        """
        return {'text': self.infer(input, **kwargs)}

    def quantize(self, bits: int):
        """Replace the wrapped model with ``self.model.quantize(bits)``.

        NOTE(review): this relies on the loaded model object exposing a
        ``quantize`` method; confirm the checkpoint's model class provides
        one, otherwise this call raises ``AttributeError``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.model = self.model.quantize(bits)
        return self

    def infer(self, input, max_new_tokens=1024, **kwargs):
        """Run generation for one prompt and return only the new text.

        Args:
            input: The user prompt inserted into ``Vicuna_PROMPT_FORMAT``.
            max_new_tokens: Passed to ``generate`` as ``max_new_tokens``.
            **kwargs: Extra keyword arguments forwarded to ``generate``.

        Returns:
            The decoded continuation (special tokens skipped, surrounding
            whitespace stripped), excluding the prompt itself.
        """
        kwargs['max_new_tokens'] = max_new_tokens
        device = self.model.device

        prompt = Vicuna_PROMPT_FORMAT.format(prompt=input)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)

        # Pass the tokenizer's attention mask explicitly so `generate` does
        # not have to guess it; a caller-supplied mask takes precedence.
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            kwargs.setdefault('attention_mask', attention_mask.to(device))

        output_ids = self.model.generate(input_ids, **kwargs)

        # `generate` returns prompt + continuation; drop the prompt tokens.
        prompt_length = len(input_ids[0])
        new_token_ids = output_ids[0][prompt_length:]
        return self.tokenizer.decode(new_token_ids,
                                     skip_special_tokens=True).strip()