# Copyright (c) 2022 Zhipu.AI
"""ModelScope pipeline and model wrapper for Bloom-560m text generation."""
import os
from typing import Any, Dict, List, Union

import torch
from modelscope.models.base import Model, TorchModel
from modelscope.models.builder import MODELS
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): this unconditionally overrides any value the caller already
# exported, for every process that imports this module. Kept as-is because it
# is an import-time side effect that existing deployments may rely on.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"


@PIPELINES.register_module(
    Tasks.text_generation, module_name='Bloom560m-text-generation-pipe')
class Bloom560mTextGenerationPipeline(Pipeline):
    """Text-generation pipeline that delegates to ``Bloom560mTextGeneration``.

    ``preprocess`` and ``postprocess`` are pass-throughs; all of the work
    happens in the wrapped model.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline.

        Args:
            model: Either an already-constructed model, or a string that is
                handed to ``Bloom560mTextGeneration`` as its ``model_dir``.
            *args: Accepted for signature compatibility; not forwarded to the
                base class.
            **kwargs: Forwarded to ``Pipeline.__init__``.
        """
        if isinstance(model, str):
            model = Bloom560mTextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; tokenization happens in the model."""
        # NOTE(review): the annotation says Dict, but the value is returned
        # untouched, so the real type is whatever the caller passed in.
        return inputs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Run the wrapped model on ``inputs``.

        ``forward_params`` are accepted but not passed on to the model.
        """
        return self.model(inputs)

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output unchanged."""
        return input


@MODELS.register_module(Tasks.text_generation, module_name='Bloom560m')
class Bloom560mTextGeneration(TorchModel):
    """Causal-LM wrapper loaded with ``AutoModelForCausalLM``."""

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and the model weights from ``model_dir``.

        Args:
            model_dir: Location passed to both ``from_pretrained`` calls.
            *args: Forwarded to ``TorchModel.__init__``.
            **kwargs: Forwarded to ``TorchModel.__init__``.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, device_map="auto")
        # Inference only: switch the network to evaluation mode.
        self.model = self.model.eval()

    def forward(self, input: Dict, *args, **kwargs) -> Dict[str, Any]:
        """Generate text for ``input`` and wrap it in the output dict.

        Args:
            input: Prompt handed directly to ``infer``.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            ``{'text': <decoded generation>}``.
        """
        return {'text': self.infer(input)}

    def quantize(self, bits: int):
        """Replace ``self.model`` with ``self.model.quantize(bits)``.

        Returns:
            ``self``, so calls can be chained.
        """
        # NOTE(review): assumes the loaded model object exposes a `quantize`
        # method; confirm that this holds for the Bloom checkpoints in use.
        self.model = self.model.quantize(bits)
        return self

    def infer(self, input, max_length: int = 512, num_beams: int = 1):
        """Tokenize ``input``, run generation and decode the first sequence.

        Args:
            input: Prompt passed to the tokenizer.
            max_length: Passed to ``generate`` as ``max_length``. Defaults to
                512, the value that used to be hard-coded.
            num_beams: Passed to ``generate`` as ``num_beams``. Defaults to 1,
                the value that used to be hard-coded.

        Returns:
            The result of decoding the first generated sequence.
        """
        # Put the prompt ids on the same device the model reports.
        device = self.model.device
        input_ids = self.tokenizer(
            input, return_tensors="pt").input_ids.to(device)
        # `generate` returns token ids, not logits.
        generated_ids = self.model.generate(
            input_ids, num_beams=num_beams, max_length=max_length)
        return self.tokenizer.decode(generated_ids[0].tolist())