Files
enginex-c_series-translation/fastapi_translate.py
2025-09-18 16:19:25 +08:00

140 lines
3.6 KiB
Python

import os
from typing import List, Any

import torch
import uvicorn
from fastapi import FastAPI, Query
from fastapi.responses import PlainTextResponse
from modelscope import snapshot_download
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

import logger
log = logger.get_logger(__file__)
app = FastAPI()
status = "Running"
translator = None
device = None
model_type = None
MODEL_TYPE = ("nllb200", "small100", "mbart", "opus_mt")
MODEL_DIR = "/workspace/model"
class TranslateRequest(BaseModel):
Text: str
@app.on_event("startup")
def load_model():
log.info("loading model")
global status, translator, device, model_type
model_type = extract_model_type()
log.info(f"model_type={model_type}")
fetch_model()
tokenizer, model = get_tokenizer_model()
#log.info(f"tokenizer={tokenizer}, model={model}")
model = model.to("cuda")
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, device="cuda", use_cache=True)
warm_up()
status = "Success"
log.info("model loaded successfully")
def fetch_model():
mn = os.environ.get("MODEL_NAME", "")
log.info(f"model_name={mn}")
os.makedirs(os.path.dirname(MODEL_DIR), exist_ok=True)
snapshot_download(mn, local_dir=MODEL_DIR)
def translator_helper(text):
source_lang = "zh"
target_lang = "en"
if model_type == "nllb200":
source_lang = "zho_Hans"
target_lang = "eng_Latn"
if model_type == "mbart":
source_lang = "zh_CN"
target_lang = "en_XX"
if model_type == "opus_mt":
source_lang = "eng"
target_lang = "zho"
output = translator(text, src_lang=source_lang, tgt_lang=target_lang)
log.info(f"model_type={model_type}, src_lang={source_lang}, tgt_lang={target_lang}, output={output}")
return output
def get_tokenizer_model():
if model_type == "small100":
from tokenization_small100 import SMALL100Tokenizer
tokenizer = SMALL100Tokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
else:
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
return tokenizer, model
def extract_model_type():
mt = os.environ.get("MODEL_TYPE", "")
log.info(f"model_type_input={mt}")
model = mt.lower()
if model not in MODEL_TYPE:
log.error(f"model_type {model} is not supported")
os._exit(1)
return model
def warm_up():
log.info("warming up...")
warmup_test = translator_helper("今天的天气非常好")
log.info(f"warm up completed! model_type={model_type}, response={warmup_test}")
return warmup_test
@app.get("/v1/get_status")
async def get_status():
ret = {
"data": {
"status": status
}
}
return ret
@app.post("/v1/translate")
async def translate(
payload: List[TranslateRequest],
):
if not payload:
return PlainTextResponse(text="Information missing", status_code=400)
results = []
texts = []
for trans_request in payload:
translations = []
texts.append(trans_request.Text)
outputs = translator_helper(texts)
for i in range(0, len(texts)):
translations.append({
"origin_text": texts[i],
"translated": outputs[i]['translation_text']
})
results.append({
"translations": translations
})
return results
if __name__ == '__main__':
uvicorn.run("fastapi_translate:app", host="0.0.0.0", port=80, workers=1, access_log=False)