Translation demo for Iluvatar
This commit is contained in:
139
fastapi_translate.py
Normal file
139
fastapi_translate.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import os
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from typing import List, Any
|
||||
|
||||
import uvicorn
|
||||
|
||||
from modelscope import snapshot_download
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
||||
import logger
|
||||
log = logger.get_logger(__file__)
|
||||
|
||||
app = FastAPI()

# Readiness flag reported by /v1/get_status; flipped to "Success" once
# load_model() finishes.
status = "Running"
# transformers translation pipeline, created at startup in load_model().
translator = None
# NOTE(review): listed in load_model's `global` statement but never
# assigned anywhere visible — appears unused; confirm before removing.
device = None
# One of the MODEL_TYPE values below, parsed from the env var at startup.
model_type = None

# Supported values for the MODEL_TYPE environment variable.
MODEL_TYPE = ("nllb200", "small100", "mbart", "opus_mt")
# Local directory the model snapshot is downloaded into.
MODEL_DIR = "/workspace/model"
|
||||
|
||||
class TranslateRequest(BaseModel):
    # Source text to translate. The capitalized field name is part of the
    # public request JSON schema — do not rename.
    Text: str
|
||||
|
||||
@app.on_event("startup")
def load_model():
    """Resolve, download, load and warm up the translation model at startup.

    Mutates the module globals: sets `model_type` and `translator`, and
    flips `status` to "Success" once everything is ready.
    """
    log.info("loading model")
    global status, translator, device, model_type

    model_type = extract_model_type()
    log.info(f"model_type={model_type}")

    fetch_model()

    tok, mdl = get_tokenizer_model()
    mdl = mdl.to("cuda")

    translator = pipeline(
        task="translation",
        model=mdl,
        tokenizer=tok,
        device="cuda",
        use_cache=True,
    )
    warm_up()

    status = "Success"
    log.info("model loaded successfully")
|
||||
|
||||
def fetch_model():
    """Download the model named by the MODEL_NAME env var into MODEL_DIR.

    Bug fix: the original called
    ``os.makedirs(os.path.dirname(MODEL_DIR), ...)``, which creates only
    the parent directory (``/workspace``), not MODEL_DIR itself.  Create
    the target directory directly instead.
    """
    mn = os.environ.get("MODEL_NAME", "")
    log.info(f"model_name={mn}")

    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(mn, local_dir=MODEL_DIR)
|
||||
|
||||
def translator_helper(text):
    """Run the global translation pipeline on *text*.

    Picks the language-code pair expected by the active model family;
    anything not listed (e.g. small100) falls back to plain zh -> en.
    """
    # (src, tgt) per model family.
    # NOTE(review): opus_mt maps eng -> zho, the reverse direction of
    # every other entry — confirm this is intentional.
    lang_pairs = {
        "nllb200": ("zho_Hans", "eng_Latn"),
        "mbart": ("zh_CN", "en_XX"),
        "opus_mt": ("eng", "zho"),
    }
    source_lang, target_lang = lang_pairs.get(model_type, ("zh", "en"))

    output = translator(text, src_lang=source_lang, tgt_lang=target_lang)
    log.info(f"model_type={model_type}, src_lang={source_lang}, tgt_lang={target_lang}, output={output}")

    return output
|
||||
|
||||
def get_tokenizer_model():
    """Load the tokenizer and seq2seq model from MODEL_DIR.

    small100 ships its own tokenizer class; every other model type uses
    AutoTokenizer.  The model itself loads identically either way, so
    that call is shared.
    """
    if model_type == "small100":
        from tokenization_small100 import SMALL100Tokenizer

        tok = SMALL100Tokenizer.from_pretrained(MODEL_DIR)
    else:
        tok = AutoTokenizer.from_pretrained(MODEL_DIR)

    mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
    return tok, mdl
|
||||
|
||||
def extract_model_type():
    """Read and validate the MODEL_TYPE env var; return it lowercased.

    Terminates the whole process on an unsupported value so the service
    never runs half-configured.
    """
    raw = os.environ.get("MODEL_TYPE", "")
    log.info(f"model_type_input={raw}")

    normalized = raw.lower()
    if normalized not in MODEL_TYPE:
        log.error(f"model_type {normalized} is not supported")
        os._exit(1)

    return normalized
|
||||
|
||||
def warm_up():
    """Push one sentence through the pipeline so the first real request is fast."""
    log.info("warming up...")

    result = translator_helper("今天的天气非常好")
    log.info(f"warm up completed! model_type={model_type}, response={result}")

    return result
|
||||
|
||||
@app.get("/v1/get_status")
async def get_status():
    """Report service readiness ("Running" until the model loads, then "Success")."""
    return {"data": {"status": status}}
|
||||
|
||||
@app.post("/v1/translate")
async def translate(
    payload: List[TranslateRequest],
):
    """Translate a batch of request texts in a single pipeline call.

    Returns a one-element list ``[{"translations": [...]}]`` where each
    item pairs the original text with its translation.  An empty payload
    yields a plain-text 400 response.

    Fixes:
      * ``PlainTextResponse(text=...)`` is invalid — Starlette's Response
        takes ``content`` as the first positional argument, so the
        empty-payload branch raised TypeError instead of returning 400.
      * ``translations`` was re-initialized on every loop iteration even
        though it is only built after the batch translation; it is now
        created once.
    """
    if not payload:
        return PlainTextResponse("Information missing", status_code=400)

    texts = [req.Text for req in payload]
    outputs = translator_helper(texts)

    translations = [
        {
            "origin_text": text,
            "translated": out["translation_text"],
        }
        for text, out in zip(texts, outputs)
    ]
    return [{"translations": translations}]
|
||||
|
||||
if __name__ == '__main__':
    # Single worker so the model is loaded exactly once per container;
    # access logging disabled.
    uvicorn.run("fastapi_translate:app", host="0.0.0.0", port=80, workers=1, access_log=False)
|
||||
|
||||
Reference in New Issue
Block a user