初始化项目,由ModelHub XC社区提供模型
Model: AI-ModelScope/falcon-7b-instruct Source: Original Platform
This commit is contained in:
75
ms_wrapper.py
Normal file
75
ms_wrapper.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# Restrict CUDA to device index 0 unless the environment already specifies
# which devices are visible; an existing value is left untouched.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.text_generation,
    module_name='falcon-7b-instruct-text-generation-pipe',
)
class falcon7binstructTextGenerationPipeline(Pipeline):
    """Text-generation pipeline that delegates directly to the wrapped model.

    Registered for ``Tasks.text_generation`` under the module name
    ``falcon-7b-instruct-text-generation-pipe``. Pre- and post-processing
    are pass-throughs; all work happens in the model call made by
    ``forward``.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline around a model instance or a model directory.

        Args:
            model: Either an already constructed model, or a string that is
                handed to ``falcon7binstructTextGeneration`` to create one.
            *args: Accepted for signature compatibility; not forwarded.
            **kwargs: Passed through to the base ``Pipeline`` constructor.
        """
        if isinstance(model, str):
            # A string is treated as the location to load the model from.
            model = falcon7binstructTextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; ``preprocess_params`` are ignored."""
        return inputs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Run the wrapped model on ``inputs`` and return its result.

        ``forward_params`` are accepted but not passed on to the model.
        """
        # self.model is presumably populated by the base Pipeline
        # constructor from the ``model`` keyword -- verify against modelscope.
        result = self.model(inputs)
        return result

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output unchanged; ``kwargs`` are ignored."""
        return input
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_generation,
                        module_name='falcon-7b-instruct')
class falcon7binstructTextGeneration(TorchModel):
    """Model wrapper that runs text generation through a transformers pipeline.

    Registered for ``Tasks.text_generation`` under the module name
    ``falcon-7b-instruct``. The tokenizer and a ``"text-generation"``
    transformers pipeline are both loaded from ``model_dir`` at construction
    time; ``forward`` returns the pipeline output under the ``'text'`` key.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and generation pipeline from ``model_dir``.

        Args:
            model_dir: Location passed to ``AutoTokenizer.from_pretrained``
                and to ``transformers.pipeline`` as the model to load.
            *args: Forwarded to the ``TorchModel`` constructor.
            **kwargs: Forwarded to the ``TorchModel`` constructor.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # The pipeline loads the weights itself in bfloat16, lets
        # device_map="auto" decide placement, and allows the checkpoint's
        # own modelling code to run (trust_remote_code=True).
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=model_dir,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def forward(self, input: Dict) -> Dict[str, Any]:
        """Generate for ``input`` and wrap the result as ``{'text': ...}``.

        The parameter name and signature are kept exactly as they were so
        existing callers keep working.
        """
        return {'text': self.infer(input)}

    def quantize(self, bits: int):
        """Quantize the loaded network to ``bits`` bits and return ``self``.

        The network is held by the transformers pipeline, so it is read from
        and written back to ``self.pipeline.model``. Nothing in this class
        assigns ``self.model``, which the previous implementation used.

        Raises:
            AttributeError: If the loaded model has no ``quantize`` method.
        """
        # NOTE(review): this relies on the loaded model exposing
        # ``quantize(bits)`` -- confirm the checkpoint's code provides it.
        self.pipeline.model = self.pipeline.model.quantize(bits)
        return self

    def infer(self, input):
        """Run the generation pipeline on ``input`` with fixed settings.

        Uses sampling (``do_sample=True``, ``top_k=10``), ``max_length=200``,
        a single returned sequence, and the tokenizer's EOS token id.

        Returns:
            Whatever the transformers pipeline returns for ``input``.
        """
        sequences = self.pipeline(
            input,
            max_length=200,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        return sequences
|
||||
Reference in New Issue
Block a user