"""FastAPI microservice exposing a batch zh<->en translation endpoint.

The model family is selected via the MODEL_TYPE environment variable and
downloaded from ModelScope (MODEL_NAME) into MODEL_DIR at startup.
"""
import os
from typing import List, Any

import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import PlainTextResponse
from modelscope import snapshot_download
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

import logger

log = logger.get_logger(__file__)

app = FastAPI()

# Mutable service state, populated by the startup hook.
status = "Running"
translator = None
device = None
model_type = None

# Accepted values for the MODEL_TYPE environment variable.
MODEL_TYPE = ("nllb200", "small100", "mbart", "opus_mt")
MODEL_DIR = "/workspace/model"


class TranslateRequest(BaseModel):
    # Capitalised field name is part of the external API contract — keep as-is.
    Text: str


@app.on_event("startup")
def load_model():
    """Download the configured model, build the pipeline, and warm it up.

    Sets the module-level ``translator``, ``model_type`` and ``status``
    globals. Exits the process if MODEL_TYPE is unsupported (see
    ``extract_model_type``).
    """
    log.info("loading model")
    global status, translator, device, model_type
    model_type = extract_model_type()
    log.info(f"model_type={model_type}")
    fetch_model()
    tokenizer, model = get_tokenizer_model()
    model = model.to("cuda")
    translator = pipeline(
        task="translation",
        model=model,
        tokenizer=tokenizer,
        device="cuda",
        use_cache=True,
    )
    warm_up()
    status = "Success"
    log.info("model loaded successfully")


def fetch_model():
    """Download the model named by MODEL_NAME from ModelScope into MODEL_DIR."""
    mn = os.environ.get("MODEL_NAME", "")
    log.info(f"model_name={mn}")
    # Fix: create MODEL_DIR itself. The original used os.path.dirname(MODEL_DIR),
    # which only created the parent ("/workspace"), not the target directory.
    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(mn, local_dir=MODEL_DIR)


def translator_helper(text):
    """Translate ``text`` (str or list of str) with model-specific language codes.

    Returns the raw pipeline output: a list of ``{"translation_text": ...}``
    dicts, one per input text.
    """
    source_lang = "zh"
    target_lang = "en"
    if model_type == "nllb200":
        source_lang = "zho_Hans"
        target_lang = "eng_Latn"
    elif model_type == "mbart":
        source_lang = "zh_CN"
        target_lang = "en_XX"
    elif model_type == "opus_mt":
        # NOTE(review): this branch translates eng -> zho, the reverse of every
        # other model type (zh -> en). Confirm the direction is intentional.
        source_lang = "eng"
        target_lang = "zho"
    output = translator(text, src_lang=source_lang, tgt_lang=target_lang)
    log.info(f"model_type={model_type}, src_lang={source_lang}, tgt_lang={target_lang}, output={output}")
    return output


def get_tokenizer_model():
    """Load the tokenizer/model pair from MODEL_DIR.

    small100 needs its custom tokenizer class; every other supported model
    loads through the transformers Auto* factories.
    """
    if model_type == "small100":
        from tokenization_small100 import SMALL100Tokenizer
        tokenizer = SMALL100Tokenizer.from_pretrained(MODEL_DIR)
    else:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
    return tokenizer, model


def extract_model_type():
    """Read and validate MODEL_TYPE from the environment (case-insensitive).

    Aborts the process when the value is unsupported — the service cannot
    serve anything without a known model type.
    """
    mt = os.environ.get("MODEL_TYPE", "")
    log.info(f"model_type_input={mt}")
    model = mt.lower()
    if model not in MODEL_TYPE:
        log.error(f"model_type {model} is not supported")
        os._exit(1)
    return model


def warm_up():
    """Run one translation so the first real request doesn't pay cold-start cost."""
    log.info("warming up...")
    warmup_test = translator_helper("今天的天气非常好")
    log.info(f"warm up completed! model_type={model_type}, response={warmup_test}")
    return warmup_test


@app.get("/v1/get_status")
async def get_status():
    """Report service readiness ("Running" until the model finishes loading)."""
    return {"data": {"status": status}}


@app.post("/v1/translate")
async def translate(
    payload: List[TranslateRequest],
):
    """Translate a batch of texts; returns one result entry per input item."""
    if not payload:
        # Fix: PlainTextResponse takes ``content``, not ``text`` — the original
        # keyword raised a TypeError instead of returning the intended 400.
        return PlainTextResponse(content="Information missing", status_code=400)
    texts = [trans_request.Text for trans_request in payload]
    # Fix: translate the whole batch ONCE. The original called the translator
    # inside the per-request loop on the growing ``texts`` list, doing
    # quadratic work and duplicating earlier translations in each entry.
    outputs = translator_helper(texts)
    results = []
    for text, output in zip(texts, outputs):
        results.append({
            "translations": [{
                "origin_text": text,
                "translated": output['translation_text'],
            }]
        })
    return results


if __name__ == '__main__':
    uvicorn.run("fastapi_translate:app", host="0.0.0.0", port=80, workers=1, access_log=False)