Files
bloom-560m/ms_wrapper.py
ModelHub XC 52b20216d1 初始化项目,由ModelHub XC社区提供模型
Model: AI-ModelScope/bloom-560m
Source: Original Platform
2026-05-14 16:29:50 +08:00

64 lines
2.3 KiB
Python

# Copyright (c) 2022 Zhipu.AI
from typing import List, Union
from modelscope.preprocessors import Preprocessor
import torch
import os
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.pipelines.base import InputModel, Pipeline
from modelscope.models.builder import MODELS
from modelscope.utils.logger import get_logger
from typing import Union, Dict, Any
from modelscope.models.base import Model, TorchModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Pin this process to GPU 0 only; must be set before any CUDA context is created.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
@PIPELINES.register_module(Tasks.text_generation, module_name='Bloom560m-text-generation-pipe')
class Bloom560mTextGenerationPipeline(Pipeline):
    """Thin ModelScope pipeline that delegates text generation to Bloom560mTextGeneration.

    Pre- and post-processing are identity passes; all real work happens in the
    wrapped model's forward call.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        # A string argument is treated as a model directory and resolved
        # into a Bloom560mTextGeneration instance before handing off.
        if isinstance(model, str):
            model = Bloom560mTextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        # Identity: the raw prompt goes to the model unchanged.
        return inputs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        # Delegate directly to the wrapped model.
        return self.model(inputs)

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        # Identity: the model output is already in its final form.
        return input
@MODELS.register_module(Tasks.text_generation, module_name='Bloom560m')
class Bloom560mTextGeneration(TorchModel):
    """TorchModel wrapper around a BLOOM-560m causal LM loaded from a local directory."""

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load tokenizer and weights from *model_dir*.

        Args:
            model_dir: directory containing the Hugging Face tokenizer and
                model files. Must not be None in practice — from_pretrained
                would fail on it.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # device_map="auto" lets transformers place weights on the available device(s).
        self.model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto")
        self.model = self.model.eval()

    def forward(self, input: Dict, *args, **kwargs) -> Dict[str, Any]:
        """Generate text for *input*; returns {'text': generated_string}."""
        output = {}
        res = self.infer(input)
        output['text'] = res
        return output

    def quantize(self, bits: int):
        # NOTE(review): AutoModelForCausalLM does not generally expose a
        # .quantize() method — this only works if the concrete model class
        # provides one (e.g. a ChatGLM-style checkpoint). Verify before use.
        self.model = self.model.quantize(bits)
        return self

    def infer(self, input, max_length=512, num_beams=1):
        """Run generation for a single prompt string.

        Args:
            input: prompt text passed to the tokenizer.
            max_length: cap on the total generated sequence length
                (default 512, previously hard-coded).
            num_beams: beam width; the default of 1 keeps the original
                greedy-decoding behaviour.

        Returns:
            The decoded output string (includes the prompt and any special
            tokens, matching the original decode call).
        """
        device = self.model.device
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids.to(device)
        # Inference only: skip autograd graph construction.
        with torch.no_grad():
            logits = self.model.generate(input_ids, num_beams=num_beams, max_length=max_length)
        out = self.tokenizer.decode(logits[0].tolist())
        return out