# fastapi_translate: FastAPI service exposing a Chinese->English translation pipeline.
import os

from fastapi import FastAPI, Query
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from typing import List, Any
import uvicorn
from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

import logger  # project-local logging helper — TODO confirm module location

log = logger.get_logger(__file__)
|
|
|
|
app = FastAPI()

# Lifecycle state, mutated by load_model() at startup.
status = "Running"   # becomes "Success" once the model has loaded and warmed up
translator = None    # transformers translation pipeline, set in load_model()
device = None        # declared global in load_model() but never assigned — TODO confirm intent
model_type = None    # one of MODEL_TYPE, parsed from the $MODEL_TYPE env var

# Supported values for the MODEL_TYPE environment variable.
MODEL_TYPE = ("nllb200", "small100", "mbart", "opus_mt")
MODEL_DIR = "/workspace/model"  # local download target passed to snapshot_download
|
|
|
|
class TranslateRequest(BaseModel):
    """One item of the /v1/translate request body."""

    # Capitalized field name is part of the external JSON contract ({"Text": ...}).
    Text: str
|
|
|
|
@app.on_event("startup")
|
|
def load_model():
|
|
log.info("loading model")
|
|
global status, translator, device, model_type
|
|
|
|
model_type = extract_model_type()
|
|
log.info(f"model_type={model_type}")
|
|
|
|
fetch_model()
|
|
|
|
tokenizer, model = get_tokenizer_model()
|
|
#log.info(f"tokenizer={tokenizer}, model={model}")
|
|
model = model.to("cuda")
|
|
|
|
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, device="cuda", use_cache=True)
|
|
warm_up()
|
|
|
|
status = "Success"
|
|
log.info("model loaded successfully")
|
|
|
|
def fetch_model():
    """Download the model named by $MODEL_NAME into MODEL_DIR via ModelScope.

    Side effects: creates MODEL_DIR (and parents) on disk and populates it
    with the model snapshot. An empty/missing $MODEL_NAME is passed through
    to snapshot_download, which will fail there.
    """
    mn = os.environ.get("MODEL_NAME", "")
    log.info(f"model_name={mn}")

    # BUG FIX: the original called makedirs on os.path.dirname(MODEL_DIR),
    # i.e. "/workspace", and never created MODEL_DIR itself. makedirs with
    # exist_ok=True creates all missing parents, so this is a strict superset.
    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(mn, local_dir=MODEL_DIR)
|
|
|
|
def translator_helper(text):
    """Run the global translation pipeline on *text* (a str or list of str).

    Chooses the language-code pair expected by the loaded model family,
    defaulting to plain "zh"/"en" for families without a specific mapping
    (e.g. "small100").
    """
    # Per-family (source, target) language codes.
    # NOTE(review): "opus_mt" maps eng -> zho while every other family maps
    # Chinese -> English — this looks inverted relative to the warm-up input;
    # confirm against the deployed opus_mt model before changing.
    lang_pairs = {
        "nllb200": ("zho_Hans", "eng_Latn"),
        "mbart": ("zh_CN", "en_XX"),
        "opus_mt": ("eng", "zho"),
    }
    source_lang, target_lang = lang_pairs.get(model_type, ("zh", "en"))

    output = translator(text, src_lang=source_lang, tgt_lang=target_lang)
    log.info(f"model_type={model_type}, src_lang={source_lang}, tgt_lang={target_lang}, output={output}")

    return output
|
|
|
|
def get_tokenizer_model():
    """Load the tokenizer and seq2seq model from MODEL_DIR.

    small100 ships its own tokenizer implementation (imported lazily from the
    downloaded snapshot); every other supported family uses AutoTokenizer.

    Returns:
        (tokenizer, model) tuple ready to hand to a transformers pipeline.
    """
    if model_type == "small100":
        # small100 requires its bundled tokenizer class, not AutoTokenizer.
        from tokenization_small100 import SMALL100Tokenizer
        tokenizer = SMALL100Tokenizer.from_pretrained(MODEL_DIR)
    else:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

    # The model load was duplicated in both branches of the original —
    # it is identical for all families, so load it once here.
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)

    return tokenizer, model
|
|
|
|
def extract_model_type():
    """Read $MODEL_TYPE, lower-case it, and validate against MODEL_TYPE.

    Returns the normalized model type string. On an unsupported value the
    process is terminated immediately via os._exit(1) (skips cleanup/atexit).
    """
    raw = os.environ.get("MODEL_TYPE", "")
    log.info(f"model_type_input={raw}")

    normalized = raw.lower()
    if normalized in MODEL_TYPE:
        return normalized

    # Hard stop on bad configuration — the service cannot run without a
    # supported model type.
    log.error(f"model_type {normalized} is not supported")
    os._exit(1)
|
|
|
|
def warm_up():
    """Push one sample sentence through the pipeline so the first real
    request does not pay the cold-start cost. Returns the pipeline output."""
    log.info("warming up...")

    result = translator_helper("今天的天气非常好")
    log.info(f"warm up completed! model_type={model_type}, response={result}")

    return result
|
|
|
|
@app.get("/v1/get_status")
|
|
async def get_status():
|
|
ret = {
|
|
"data": {
|
|
"status": status
|
|
}
|
|
}
|
|
return ret
|
|
|
|
@app.post("/v1/translate")
|
|
async def translate(
|
|
payload: List[TranslateRequest],
|
|
):
|
|
if not payload:
|
|
return PlainTextResponse(text="Information missing", status_code=400)
|
|
results = []
|
|
texts = []
|
|
for trans_request in payload:
|
|
translations = []
|
|
texts.append(trans_request.Text)
|
|
|
|
outputs = translator_helper(texts)
|
|
for i in range(0, len(texts)):
|
|
translations.append({
|
|
"origin_text": texts[i],
|
|
"translated": outputs[i]['translation_text']
|
|
})
|
|
|
|
results.append({
|
|
"translations": translations
|
|
})
|
|
return results
|
|
|
|
if __name__ == '__main__':
    # Single worker: the model lives in this process's memory/GPU, so
    # multiple workers would each load their own copy.
    uvicorn.run("fastapi_translate:app", host="0.0.0.0", port=80, workers=1, access_log=False)
|
|
|
|
|