Files
Ziya-LLaMA-13B-v1.1/ms_wrapper.py

68 lines
2.5 KiB
Python
Raw Normal View History

import os
from typing import Union, Dict, Any
from modelscope.pipelines.builder import PIPELINES
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.pipelines.base import Pipeline
from modelscope.models.base import Model, TorchModel
from modelscope.utils.logger import get_logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaForCausalLM
# Default CUDA_VISIBLE_DEVICES to '0' only when it is not already set, so a
# value chosen by the caller's environment is never overwritten.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
@PIPELINES.register_module(Tasks.text_generation, module_name='Ziya-LLaMA-13B-v1-text-generation-pipe')
class ZiyaLLaMA13Bv1TextGenerationPipeline(Pipeline):
    """Text-generation pipeline registered as 'Ziya-LLaMA-13B-v1-text-generation-pipe'.

    The pipeline itself does no transformation: ``preprocess`` and
    ``postprocess`` hand their argument back unchanged, and ``forward``
    simply calls the wrapped model.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline around a model instance or a model directory.

        Args:
            model: Either an already constructed model, or a string that is
                passed to ``ZiyaLLaMA13Bv1TextGeneration`` to build one.
            *args: Accepted for signature compatibility; not forwarded to
                the base ``Pipeline`` constructor.
            **kwargs: Forwarded to the base ``Pipeline`` constructor.
        """
        if isinstance(model, str):
            # A string is wrapped into the model class defined in this file.
            model = ZiyaLLaMA13Bv1TextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` untouched; ``preprocess_params`` are ignored."""
        return inputs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Call the wrapped model on ``inputs``.

        ``forward_params`` are accepted but not passed on to the model.
        """
        return self.model(inputs)

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output untouched; ``kwargs`` are ignored."""
        return input
@MODELS.register_module(Tasks.text_generation, module_name='Ziya-LLaMA-13B-v1')
class ZiyaLLaMA13Bv1TextGeneration(TorchModel):
    """Model wrapper registered as 'Ziya-LLaMA-13B-v1' for text generation.

    Loads a tokenizer and a ``LlamaForCausalLM`` from ``model_dir`` and
    exposes ``forward``/``infer`` to turn a prompt into generated text.
    """

    # Generation settings applied by ``infer`` unless the caller overrides
    # them. Treated as read-only: ``infer`` copies it before merging.
    _DEFAULT_GENERATE_KWARGS = {
        'max_new_tokens': 1024,
        'do_sample': True,
        'top_p': 0.85,
        'temperature': 1.0,
        'repetition_penalty': 1.,
        'eos_token_id': 2,
        'bos_token_id': 1,
        'pad_token_id': 0,
    }

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and the causal LM from ``model_dir``.

        Args:
            model_dir: Directory (or identifier) handed to both
                ``AutoTokenizer.from_pretrained`` and
                ``LlamaForCausalLM.from_pretrained``.
            *args: Forwarded to the ``TorchModel`` constructor.
            **kwargs: Forwarded to the ``TorchModel`` constructor.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = LlamaForCausalLM.from_pretrained(model_dir, device_map="auto")
        self.model = self.model.eval()

    def forward(self, input: Dict) -> Dict[str, Any]:
        """Generate text for ``input`` and wrap it in a result dict.

        Args:
            input: Passed unchanged to ``infer`` and from there straight to
                the tokenizer. NOTE(review): annotated as ``Dict`` but
                presumably a prompt string -- verify against callers.

        Returns:
            A dict with a single key ``'text'`` holding the decoded output.
        """
        return {'text': self.infer(input)}

    def quantize(self, bits: int):
        """Replace the wrapped model with ``self.model.quantize(bits)``.

        NOTE(review): this relies on the loaded model exposing a
        ``quantize`` method; confirm that ``LlamaForCausalLM`` provides one,
        otherwise this call raises ``AttributeError``.

        Returns:
            ``self``, so the call can be chained.
        """
        self.model = self.model.quantize(bits)
        return self

    def infer(self, input, **generate_kwargs):
        """Tokenize ``input``, run generation and decode the first sequence.

        Args:
            input: Prompt passed directly to the tokenizer.
            **generate_kwargs: Optional overrides for the generation
                settings; any key not supplied falls back to
                ``_DEFAULT_GENERATE_KWARGS``. Calling without overrides uses
                the same fixed settings this method always applied.

        Returns:
            The decoded text of the first generated sequence. Decoding is
            applied to the full ``generate`` output with default
            ``batch_decode`` options.
        """
        device = self.model.device
        input_ids = self.tokenizer(input, return_tensors="pt").input_ids.to(device)
        # Caller-supplied values win over the defaults; the merge builds a
        # new dict so the class-level defaults are never mutated.
        options = {**self._DEFAULT_GENERATE_KWARGS, **generate_kwargs}
        generated_ids = self.model.generate(input_ids, **options)
        return self.tokenizer.batch_decode(generated_ids)[0]