"""FastAPI wrapper exposing a DeepSeek-OCR model behind a vLLM-style chat endpoint.

The module loads the tokenizer and model once at import time, then serves:

* ``GET /health``    - liveness probe.
* ``POST /generate`` - runs OCR on the first image found in the request
  messages and returns the text in a chat-completion shaped JSON payload.
"""

import base64
import io
import logging
import os
import shutil
import tempfile
import time
from typing import Any, Dict, List, Optional

import torch
from fastapi import FastAPI, HTTPException
from modelscope import AutoModel, AutoTokenizer
from PIL import Image
from pydantic import BaseModel
from starlette.responses import JSONResponse

logger = logging.getLogger(__name__)

# -------- Configuration --------
MODEL_DIR = os.environ.get("DEESEEK_MODEL_DIR", "/model")
MODEL_PREFERRED_DTYPE = os.environ.get("DEESEEK_DTYPE", "bfloat16")  # or float16/float32

# Directory handed to model.infer as ``output_path`` and searched for the result file.
DEFAULT_OUTPUT_DIR = "./output/"
# File looked up in the output directory after inference.
# NOTE(review): presumably written by model.infer when save_results is true - confirm.
RESULT_FILENAME = "result.mmd"
# Prompt sent to the model; text parts of the request are currently not used as prompt.
DEFAULT_PROMPT = "\nFree OCR."

# -------- FastAPI app --------
app = FastAPI(title="DeepSeek-OCR vllm-format wrapper")


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``messages`` is a list of chat messages; each message is a dict with a
    ``role`` and a ``content`` (list of typed parts, or a plain string).
    The remaining fields are forwarded as keyword arguments to ``model.infer``.
    """

    messages: List[Dict[str, Any]]
    # Optional parameters forwarded unchanged to model.infer.
    base_size: Optional[int] = 1024
    image_size: Optional[int] = 640
    crop_mode: Optional[bool] = True
    save_results: Optional[bool] = True
    test_compress: Optional[bool] = True


def _decode_data_uri_image(data_uri: str) -> Image.Image:
    """Decode a ``data:image/...;base64,xxxx`` URI into an RGB ``PIL.Image``.

    Raises:
        ValueError: if the string is not a data URI or has no ``,`` separator.
            (Invalid base64 / image bytes raise whatever ``base64`` / PIL raise.)
    """
    if not data_uri.startswith("data:"):
        raise ValueError("Not a data URI")
    _header, separator, payload = data_uri.partition(",")
    if not separator:
        raise ValueError("Data URI has no ',' separating header and payload")
    decoded = base64.b64decode(payload)
    return Image.open(io.BytesIO(decoded)).convert("RGB")


def _resolve_dtype(name: str) -> torch.dtype:
    """Map a dtype name such as ``"float16"`` to a ``torch.dtype`` (bfloat16 if unknown)."""
    candidate = getattr(torch, name, None)
    if isinstance(candidate, torch.dtype):
        return candidate
    logger.warning("Unknown dtype %r; falling back to bfloat16", name)
    return torch.bfloat16


# Load tokenizer + model
print("Loading tokenizer and model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
except Exception as e:
    print(f"Failed to load tokenizer from {MODEL_DIR}: {e}")
    raise

try:
    model = AutoModel.from_pretrained(MODEL_DIR, trust_remote_code=True, use_safetensors=True)
except Exception as e:
    print(f"Failed to load model from {MODEL_DIR}: {e}")
    raise

# Move to device and set dtype if possible. Failure is deliberately non-fatal:
# the service still starts with the model as loaded.
try:
    model = model.eval().cuda().to(_resolve_dtype(MODEL_PREFERRED_DTYPE))
except Exception as e:
    print(f"Warning while preparing model device/dtype: {e}")

print("Model loaded and prepared.")


def _to_conversation(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert vllm-style messages into the internal conversation format.

    ``image_url`` parts become ``{"type": "image", "image": <url>}``, ``text``
    parts are normalised, unknown dict parts are passed through unchanged.
    A plain-string ``content`` (or string item) is treated as a text part.

    Raises:
        HTTPException: 400 if ``content`` or one of its items has an unsupported type.
    """
    conversation = []
    for message in messages:
        role = message.get("role", "user")
        raw_content = message.get("content", [])
        if raw_content is None:
            raw_content = []
        elif isinstance(raw_content, str):
            # Chat APIs commonly allow content to be a bare string.
            raw_content = [{"type": "text", "text": raw_content}]
        elif not isinstance(raw_content, (list, tuple)):
            raise HTTPException(
                status_code=400,
                detail="message content must be a string or a list of content parts",
            )

        parts = []
        for item in raw_content:
            if isinstance(item, str):
                parts.append({"type": "text", "text": item})
                continue
            if not isinstance(item, dict):
                raise HTTPException(
                    status_code=400,
                    detail="each content part must be an object or a string",
                )
            part_type = item.get("type")
            if part_type == "image_url":
                image_url = item.get("image_url")
                # Accept both {"image_url": {"url": ...}} and {"image_url": "..."}.
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                parts.append({"type": "image", "image": url})
            elif part_type == "text":
                parts.append({"type": "text", "text": item.get("text", "")})
            else:
                parts.append(item)
        conversation.append({"role": role, "content": parts})
    return conversation


def _collect_images(conversation: List[Dict[str, Any]], temp_files: List[str]) -> List[str]:
    """Return image references for inference, in message order.

    Data URIs are decoded and written to uniquely named temporary PNG files
    (their paths are appended to ``temp_files`` so the caller can delete
    them); any other string is passed through as a path/URL for the model.
    Entries without a usable string reference are skipped.

    Raises:
        HTTPException: 400 if a data URI cannot be decoded.
    """
    images = []
    for message in conversation:
        for part in message["content"]:
            if part.get("type") != "image":
                continue
            image_ref = part.get("image")
            if not isinstance(image_ref, str) or not image_ref:
                continue
            if not image_ref.startswith("data:"):
                # Assume it's a path or URL acceptable to model.infer.
                images.append(image_ref)
                continue
            try:
                picture = _decode_data_uri_image(image_ref)
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"failed to decode data URI image: {e}",
                ) from e
            # Unique name: timestamp-based names can collide between images
            # of one request or between concurrent requests.
            handle = tempfile.NamedTemporaryFile(
                prefix="deepproc_", suffix=".png", delete=False
            )
            handle.close()
            # Register before saving so the file is cleaned up even if save fails.
            temp_files.append(handle.name)
            picture.save(handle.name)
            images.append(handle.name)
    return images


def _read_result_file(path: str) -> Optional[str]:
    """Return the stripped content of ``path``, or ``None`` if missing, empty or unreadable."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read().strip()
    except FileNotFoundError:
        return None
    except (OSError, UnicodeDecodeError) as e:
        # Log but don't fail; the caller falls back to the model's return value.
        logger.warning("Failed to read %s: %s", path, e)
        return None
    return content or None


def _text_from_result(res: Any) -> Optional[str]:
    """Best-effort extraction of OCR text from the value returned by ``model.infer``.

    Handles dicts (keys ``text`` / ``result`` / ``ocr_text``), sequences
    (first element) and arbitrary objects (``str()``). Returns ``None`` when
    there is nothing to extract.
    """
    if res is None:
        return None
    if isinstance(res, dict):
        value = res.get("text") or res.get("result") or res.get("ocr_text")
        return str(value) if value is not None else str(res)
    if isinstance(res, (list, tuple)):
        return str(res[0]) if res else None
    return str(res)


# -------- Routes --------
@app.get("/health")
def health_check():
    """Liveness probe; always answers ``{"status": "ok"}`` with HTTP 200."""
    return JSONResponse(status_code=200, content={"status": "ok"})


@app.post("/generate")
def generate(req: GenerateRequest):
    """Run OCR on the first image in ``req.messages``.

    Returns:
        JSONResponse shaped like a chat completion, with the OCR text in
        ``choices[0].message.content``.

    Raises:
        HTTPException: 400 for empty/malformed messages, undecodable data
            URIs or when no image is supplied; 500 when inference yields no
            usable text.
    """
    messages = req.messages
    if not messages or not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a non-empty list")

    conversation = _to_conversation(messages)

    temp_files: List[str] = []
    try:
        images_for_infer = _collect_images(conversation, temp_files)
        if not images_for_infer:
            raise HTTPException(status_code=400, detail="no images provided")
        # Only the first image is processed; batch inference is not implemented.
        image_input = images_for_infer[0]

        # GenerateRequest declares no output_path field today; the getattr keeps
        # working if one is added later.
        output_path = getattr(req, "output_path", None) or DEFAULT_OUTPUT_DIR
        os.makedirs(output_path, exist_ok=True)
        result_mmd_path = os.path.join(output_path, RESULT_FILENAME)

        # Drop any result left by an earlier request so it cannot be returned
        # as this request's answer when the current run writes no file.
        try:
            os.remove(result_mmd_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Could not remove stale %s: %s", result_mmd_path, e)

        try:
            res = model.infer(
                tokenizer,
                prompt=DEFAULT_PROMPT,
                image_file=image_input,
                output_path=output_path,
                base_size=req.base_size,
                image_size=req.image_size,
                crop_mode=req.crop_mode,
                save_results=req.save_results,
                test_compress=req.test_compress,
            )
        except TypeError:
            # Fallback: some implementations may only accept positional args.
            res = model.infer(tokenizer, DEFAULT_PROMPT, image_input)

        print("res:\n", res)

        # Prefer the result file; otherwise fall back to the returned value.
        ocr_text = _read_result_file(result_mmd_path)
        if ocr_text is None:
            ocr_text = _text_from_result(res)
        if ocr_text is None:
            raise HTTPException(status_code=500, detail="OCR produced no output")

        response = {
            "id": "chatcmpl-deepseek",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": os.path.basename(MODEL_DIR),
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": ocr_text,
                    },
                    "finish_reason": "stop",
                }
            ],
        }
        return JSONResponse(response)
    finally:
        # Clean up the temp files we created; failures here must not mask the response.
        for temp_path in temp_files:
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning("Could not remove temp file %s: %s", temp_path, e)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=80)