# enginex-bi_series-translation/fastapi_translate.py
# FastAPI translation service (exported 2025-08-26 19:04:22 +08:00)
import os
from fastapi import FastAPI, Query
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from typing import List, Any
import uvicorn
from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import logger
log = logger.get_logger(__file__)  # project-local logging helper
app = FastAPI()
# Module-level state; mutated by the startup hook (load_model).
status = "Running"  # reported by /v1/get_status; becomes "Success" once the model is loaded
translator = None  # transformers translation pipeline, created at startup
device = None  # target device string; assigned during startup
model_type = None  # one of MODEL_TYPE, parsed from the MODEL_TYPE env var
# Supported values for the MODEL_TYPE environment variable.
MODEL_TYPE = ("nllb200", "small100", "mbart", "opus_mt")
# Local directory the ModelScope snapshot is downloaded into.
MODEL_DIR = "/workspace/model"
class TranslateRequest(BaseModel):
    """Single item of the /v1/translate request body: one text to translate."""

    # Capitalized field name matches the caller's JSON key ("Text").
    Text: str
@app.on_event("startup")
def load_model():
    """Download, load, and warm up the translation model at app startup.

    Populates the module globals ``model_type``, ``device``, ``translator``
    and flips ``status`` to "Success" once the pipeline is ready.
    """
    log.info("loading model")
    global status, translator, device, model_type
    model_type = extract_model_type()
    log.info(f"model_type={model_type}")
    fetch_model()
    tokenizer, model = get_tokenizer_model()
    # Bug fix: the device was hard-coded to "cuda" (crashing on CPU-only
    # hosts) and the `device` global was declared but never assigned.
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"
    log.info(f"device={device}")
    model = model.to(device)
    translator = pipeline(task="translation", model=model, tokenizer=tokenizer, device=device, use_cache=True)
    warm_up()
    status = "Success"
    log.info("model loaded successfully")
def fetch_model():
    """Download the model named by $MODEL_NAME into MODEL_DIR via ModelScope."""
    mn = os.environ.get("MODEL_NAME", "")
    log.info(f"model_name={mn}")
    # Bug fix: the original created only the parent directory
    # (os.path.dirname(MODEL_DIR) == "/workspace"), not MODEL_DIR itself.
    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(mn, local_dir=MODEL_DIR)
def translator_helper(text):
    """Translate ``text`` (a string or list of strings) with the loaded
    pipeline, using the language codes the current model family expects.

    Returns the raw pipeline output (list of dicts with "translation_text").
    """
    # Per-model source/target code pairs; anything else (e.g. small100)
    # falls back to the generic zh -> en pair.
    # NOTE(review): opus_mt maps eng -> zho, the reverse direction of the
    # other models — confirm this is intentional.
    lang_pairs = {
        "nllb200": ("zho_Hans", "eng_Latn"),
        "mbart": ("zh_CN", "en_XX"),
        "opus_mt": ("eng", "zho"),
    }
    source_lang, target_lang = lang_pairs.get(model_type, ("zh", "en"))
    output = translator(text, src_lang=source_lang, tgt_lang=target_lang)
    log.info(f"model_type={model_type}, src_lang={source_lang}, tgt_lang={target_lang}, output={output}")
    return output
def get_tokenizer_model():
    """Load the tokenizer and seq2seq model from MODEL_DIR.

    small100 ships its own tokenizer class and cannot go through
    AutoTokenizer; the model itself always loads via AutoModelForSeq2SeqLM.
    """
    if model_type == "small100":
        from tokenization_small100 import SMALL100Tokenizer
        tokenizer = SMALL100Tokenizer.from_pretrained(MODEL_DIR)
    else:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
    return tokenizer, model
def extract_model_type():
    """Read and validate $MODEL_TYPE (case-insensitive).

    Returns the lowercased model type; hard-exits the process when the
    value is not one of the supported MODEL_TYPE entries.
    """
    raw = os.environ.get("MODEL_TYPE", "")
    log.info(f"model_type_input={raw}")
    normalized = raw.lower()
    if normalized in MODEL_TYPE:
        return normalized
    log.error(f"model_type {normalized} is not supported")
    os._exit(1)
def warm_up():
    """Run one translation so the first real request does not pay the
    lazy-initialization cost; returns the warm-up pipeline output."""
    log.info("warming up...")
    result = translator_helper("今天的天气非常好")
    log.info(f"warm up completed! model_type={model_type}, response={result}")
    return result
@app.get("/v1/get_status")
async def get_status():
    """Report the service status ("Running" until the model finishes loading,
    then "Success")."""
    return {"data": {"status": status}}
@app.post("/v1/translate")
async def translate(
    payload: List[TranslateRequest],
):
    """Translate a batch of texts.

    Returns one entry per request item, each with a single
    origin_text/translated pair; 400 when the payload is empty.
    """
    if not payload:
        # Bug fix: PlainTextResponse takes the body as its first positional
        # argument (`content`); the original `text=` keyword raised TypeError.
        return PlainTextResponse("Information missing", status_code=400)
    # Bug fix: the original ran the pipeline on the growing prefix of texts
    # inside the loop (O(n^2) pipeline calls) and returned cumulative
    # duplicate translations; translate the whole batch exactly once.
    texts = [req.Text for req in payload]
    outputs = translator_helper(texts)
    return [
        {
            "translations": [
                {
                    "origin_text": text,
                    "translated": output["translation_text"],
                }
            ]
        }
        for text, output in zip(texts, outputs)
    ]
if __name__ == "__main__":
    # Single worker: the model is loaded once per process by the startup hook.
    uvicorn.run("fastapi_translate:app", host="0.0.0.0", port=80, workers=1, access_log=False)