Files
Ziya-LLaMA-13B-v1/ms_wrapper.py
ModelHub XC 8c6c349b99 初始化项目,由ModelHub XC社区提供模型
Model: Fengshenbang/Ziya-LLaMA-13B-v1
Source: Original Platform
2026-05-25 16:25:14 +08:00

71 lines
2.6 KiB
Python

import os
from typing import Union, Dict, Any
from modelscope.pipelines.builder import PIPELINES
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.nlp.text_generation_pipeline import TextGenerationPipeline
from modelscope.models.base import Model, TorchModel
from modelscope.utils.logger import get_logger
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import LlamaForCausalLM
@PIPELINES.register_module(Tasks.text_generation, module_name='Ziya-LLaMA-13B-v1-text-generation-pipe')
class ZiyaLLaMA13Bv1TextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline for Ziya-LLaMA-13B-v1.

    Pre- and post-processing are identities: the raw input goes straight to
    the model, and the model's output is returned as-is. All call-time
    parameters are placed in the forward slot of ``_sanitize_parameters``.
    """

    def __init__(self, model: Union[Model, str], *args, **kwargs):
        """Build the pipeline.

        Args:
            model: Either a ready model instance or a string. A string is
                wrapped in ``ZiyaLLaMA13Bv1TextGeneration`` (loaded from that
                path).
            *args: Accepted for signature compatibility but NOT forwarded to
                the base class.
            **kwargs: Forwarded to ``TextGenerationPipeline.__init__``.
        """
        if isinstance(model, str):
            model = ZiyaLLaMA13Bv1TextGeneration(model)
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; no tokenization happens here."""
        return inputs

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route every call-time keyword argument to the forward step.

        Presumably unpacked by the ModelScope base ``Pipeline`` as
        ``(preprocess_params, forward_params, postprocess_params)``.
        """
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        return preprocess_kwargs, pipeline_parameters, postprocess_kwargs

    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Call the wrapped model on ``inputs`` with the forward parameters."""
        model_output = self.model(inputs, **forward_params)
        return model_output

    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the model output unchanged."""
        return input
@MODELS.register_module(Tasks.text_generation, module_name='Ziya-LLaMA-13B-v1')
class ZiyaLLaMA13Bv1TextGeneration(TorchModel):
    """ModelScope model wrapping the Ziya-LLaMA-13B-v1 causal language model.

    Loads a tokenizer and ``LlamaForCausalLM`` weights from ``model_dir`` and
    turns a prompt string into generated text.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load the tokenizer and model weights.

        Args:
            model_dir: Path passed to both ``from_pretrained`` calls and to the
                ``TorchModel`` base class.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # device_map="auto" hands weight placement to transformers/accelerate.
        self.model = LlamaForCausalLM.from_pretrained(model_dir, device_map="auto")
        # Inference only: put the model in eval mode (disables dropout).
        self.model = self.model.eval()

    def forward(self, input: str, *args, **kwargs) -> Dict[str, Any]:
        """Generate a reply to the prompt ``input``.

        Args:
            input: Prompt text.
            *args: Accepted for interface compatibility; unused.
            **kwargs: Forwarded to ``generate`` (e.g. ``max_new_tokens``,
                sampling options).

        Returns:
            ``{'text': reply}``. ``reply`` is only the newly generated text,
            decoded without the prompt and without special tokens.
        """
        prompt_ids, output_ids = self._generate(input, **kwargs)
        # Strip the prompt by token count, not by character count. Decoded
        # text need not reproduce the prompt character-for-character (spacing
        # around special tokens can differ), so slicing the decoded string by
        # len(prompt) plus a fudge offset can cut into the reply or leave
        # prompt residue. For decoder-only models, generate() returns the
        # prompt ids followed by the new ids, so everything after
        # prompt_ids.shape[-1] is newly generated.
        new_ids = output_ids[0, prompt_ids.shape[-1]:]
        text = self.tokenizer.decode(new_ids, skip_special_tokens=True)
        return {'text': text}

    def quantize(self, bits: int):
        """Replace ``self.model`` with ``self.model.quantize(bits)``; return self.

        Raises:
            NotImplementedError: The loaded model has no ``quantize`` method.
        """
        quantize_fn = getattr(self.model, 'quantize', None)
        if quantize_fn is None:
            # NOTE(review): LlamaForCausalLM is not known to expose a
            # quantize(bits) method; verify against the installed transformers
            # version. Fail with an explicit message instead of an opaque
            # AttributeError.
            raise NotImplementedError(
                f'{type(self.model).__name__} does not support quantize(bits={bits})')
        self.model = quantize_fn(bits)
        return self

    def _generate(self, input, max_new_tokens=1024, **kwargs):
        """Tokenize ``input`` and run ``generate`` on the model's device.

        Returns:
            ``(prompt_ids, output_ids)``: the prompt token ids and the
            sequences returned by ``generate``.
        """
        device = self.model.device
        kwargs['max_new_tokens'] = max_new_tokens
        prompt_ids = self.tokenizer(input, return_tensors="pt").input_ids.to(device)
        output_ids = self.model.generate(prompt_ids, **kwargs)
        return prompt_ids, output_ids

    def infer(self, input, max_new_tokens=1024, **kwargs):
        """Generate from ``input`` and decode the first returned sequence.

        Nothing is skipped during decoding, so special tokens such as
        ``<s>``/``</s>`` are kept. For decoder-only models the sequence also
        starts with the prompt. Use ``forward`` to get only the reply.
        """
        _, output_ids = self._generate(input, max_new_tokens=max_new_tokens, **kwargs)
        return self.tokenizer.batch_decode(output_ids)[0]