From da8ac28a97a45a80c32be015bf1d5569db930e90 Mon Sep 17 00:00:00 2001 From: luopingyi Date: Sat, 1 Nov 2025 11:52:15 +0800 Subject: [PATCH] init --- Dockerfile | 11 + README.md | 53 +- app.py | 224 +++++++++ modeling_deepseekocr.py | 1037 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 1324 insertions(+), 1 deletion(-) create mode 100644 Dockerfile create mode 100644 app.py create mode 100644 modeling_deepseekocr.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..54a422a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +FROM git.modelhub.org.cn:9443/enginex-ascend/vllm-ascend:v0.11.0rc0 + +WORKDIR /app + +RUN pip install transformers==4.46.3 einops addict easydict modelscope uvicorn fastapi + +COPY app.py . + +ENTRYPOINT [] + +CMD ["python", "-m", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "80"] diff --git a/README.md b/README.md index 00c6b41..44c8c05 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,53 @@ -# enginex-ascend-910-vllm +# enginex-ascend-910-transformer-deepseekOCR + +运行于【昇腾-910】系列算力卡的【视觉多模态】引擎,基于 transformer 引擎进行架构特别适配优化,支持 DeepSeek-OCR最新开源模型 + +## QuickStart + +1、从 modelscope上下载支持 DeepSeek-OCR +```python +modelscope download --model deepseek-ai/DeepSeek-OCR README.md --local_dir ./model +``` +将仓库里的 modeling_deepseekocr.py 复制到模型目录覆盖原本的文件 + +2、使用Dockerfile生成镜像 +从仓库的【软件包】栏目下载基础镜像 git.modelhub.org.cn:9443/enginex-ascend/vllm-ascend:v0.11.0rc0 + +使用 Dockerfile 生成 镜像 +```python +docker build -f Dockerfile -t ascend:deepseek_ocr . 
+``` + + +3、启动docker +```python +docker run -it --rm \ + -p 10086:80 \ + --name test-ascend-my-1 \ + -v `pwd`:/host \ + -e ASCEND_VISIBLE_DEVICES=1 \ + --device /dev/davinci1:/dev/davinci0 \ + --device /dev/davinci_manager \ + --device /dev/devmm_svm \ + --device /dev/hisi_hdc \ + -v ./model:/model \ + -v /usr/local/dcmi:/usr/local/dcmi \ + -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ + -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ + -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ + -v /etc/ascend_install.info:/etc/ascend_install.info \ + --privileged \ + ascend:deepseek_ocr +``` + +4、测试服务 +```python +curl -X POST http://localhost:10086/generate \ + -H "Content-Type: application/json" \ + -d '{ + "model": "qwen3-8b", + "messages": [{"role": "user", "content": "你好"}], + "stream": true + }' +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000..dc6a6a5 --- /dev/null +++ b/app.py @@ -0,0 +1,224 @@ +import os +import io +import time +import base64 +import shutil +from typing import Any, Dict, List, Optional + +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from starlette.responses import JSONResponse +from PIL import Image + +import torch +from modelscope import AutoModel, AutoTokenizer + +# -------- Configuration -------- +MODEL_DIR = os.environ.get("DEESEEK_MODEL_DIR", "/mnt/models") +MODEL_PREFERRED_DTYPE = os.environ.get("DEESEEK_DTYPE", "bfloat16") # or float16/float32 + +# -------- FastAPI app -------- +app = FastAPI(title="DeepSeek-OCR vllm-format wrapper") + +class GenerateRequest(BaseModel): + messages: List[Dict[str, Any]] + # optional params mapping to your OCR infer options + base_size: Optional[int] = 1024 + image_size: Optional[int] = 640 + crop_mode: Optional[bool] = True + save_results: Optional[bool] = True + test_compress: Optional[bool] = True + +def _decode_data_uri_image(data_uri: str) -> Image.Image: + """Decode a data:image/...;base64,xxxx 
URI into PIL.Image.""" + if not data_uri.startswith("data:"): + raise ValueError("Not a data URI") + header, b64 = data_uri.split(",", 1) + decoded = base64.b64decode(b64) + return Image.open(io.BytesIO(decoded)).convert("RGB") + +# Load tokenizer + model +print("Loading tokenizer and model...") +try: + tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True) +except Exception as e: + print(f"Failed to load tokenizer from {MODEL_DIR}: {e}") + raise + +try: + model = AutoModel.from_pretrained(MODEL_DIR, trust_remote_code=True, use_safetensors=True) +except Exception as e: + print(f"Failed to load model from {MODEL_DIR}: {e}") + raise + +# move to device and set dtype if possible +try: + model = model.eval().npu().to(torch.bfloat16) +except Exception as e: + print(f"Warning while preparing model device/dtype: {e}") + +print("Model loaded and prepared.") + +# -------- Routes -------- + +@app.get("/health") +def health_check(): + return JSONResponse(status_code=200, content={"status": "ok"}) + + +@app.post("/generate") +def generate(req: GenerateRequest): + messages = req.messages + if not messages or not isinstance(messages, list): + raise HTTPException(status_code=400, detail="messages must be a non-empty list") + + # Convert vllm-style messages -> conversation format + conversation = [] + for m in messages: + role = m.get("role", "user") + raw_content = m.get("content", []) + content_list = [] + for c in raw_content: + ctype = c.get("type") + if ctype == "image_url": + url = None + if isinstance(c.get("image_url"), dict): + url = c["image_url"].get("url") + else: + url = c.get("image_url") + content_list.append({"type": "image", "image": url}) + elif ctype == "text": + content_list.append({"type": "text", "text": c.get("text", "")}) + else: + content_list.append(c) + conversation.append({"role": role, "content": content_list}) + + # collect images (data URIs will be decoded into temporary files) + images_for_infer = [] + temp_files = [] + try: + 
for msg in conversation: + for c in msg["content"]: + if c.get("type") == "image": + img_ref = c.get("image") + if isinstance(img_ref, str) and img_ref.startswith("data:"): + try: + pil = _decode_data_uri_image(img_ref) + except Exception as e: + raise HTTPException(status_code=400, detail=f"failed to decode data URI image: {e}") + # save to temp file so model.infer can read path if it expects a path + tpath = os.path.join("/tmp", f"deepproc_{int(time.time()*1000)}.png") + pil.save(tpath) + temp_files.append(tpath) + images_for_infer.append(tpath) + else: + # assume it's a path or URL acceptable to model.infer + images_for_infer.append(img_ref) + + # Prepare prompt: for DeepSeek-OCR we typically pass something like '\nFree OCR.' as in your example. + # Allow overriding by looking for a text content in the messages. + # prompt_text = None + # for msg in conversation: + # for c in msg["content"]: + # if c.get("type") == "text" and c.get("text"): + # prompt_text = c.get("text") + # break + # if prompt_text: + # break + # if not prompt_text: + prompt_text = "\nFree OCR." # default prompt + + # call model.infer; support single image or batch (here we will pass the first image if multiple) + if len(images_for_infer) == 0: + raise HTTPException(status_code=400, detail="no images provided") + + # Use the first image by default; you can extend to batch inference. + image_input = images_for_infer[0] + + output_path = "./output/" if not hasattr(req, 'output_path') else getattr(req, 'output_path') + os.makedirs(output_path, exist_ok=True) + + # start_time = time.time() + # The example uses: model.infer(tokenizer, prompt, image_file=image_file, output_path=..., base_size=..., ...) 
+ try: + res = model.infer( + tokenizer, + prompt=prompt_text, + image_file=image_input, + output_path="./output/", #if not req.save_results else os.path.join(MODEL_DIR, "infer_out"), + base_size=req.base_size, + image_size=req.image_size, + crop_mode=req.crop_mode, + save_results=req.save_results, + test_compress=req.test_compress, + ) + except TypeError: + # fallback: try without named args if certain impls expect positional + res = model.infer(tokenizer, prompt_text, image_input) + + # end_time = time.time() + # elapsed = end_time - start_time + + print ("res:\n", res) + # print (elapsed) + + result_mmd_path = os.path.join(output_path, "result.mmd") + + try: + if os.path.isfile(result_mmd_path): + with open(result_mmd_path, "r", encoding="utf-8") as f: + file_content = f.read().strip() + if file_content: + ocr_text = file_content + except Exception as e: + # log but don't fail; we'll fall back to parsing the model response + try: + logger.warning(f"Failed to read {result_mmd_path}: {e}") + except Exception: + pass + + # prepare response content; `res` may be a dict or string depending on model impl + # ocr_text = None + # if isinstance(res, dict): + # # try common keys + # ocr_text = res.get("text") or res.get("result") or res.get("ocr_text") + # elif isinstance(res, (list, tuple)): + # # try first element + # ocr_text = res[0] if len(res) > 0 else None + # else: + # ocr_text = str(res) + + # if ocr_text is None: + # ocr_text = str(res) + + response = { + "id": "chatcmpl-deepseek", + "object": "chat.completion", + "created": int(time.time()), + "model": os.path.basename(MODEL_DIR), + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": ocr_text, + }, + "finish_reason": "stop", + } + ] + } + + return JSONResponse(response) + + finally: + # cleanup temp files we created + for t in temp_files: + try: + os.remove(t) + except Exception: + pass + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=80) + 
diff --git a/modeling_deepseekocr.py b/modeling_deepseekocr.py new file mode 100644 index 0000000..31d0159 --- /dev/null +++ b/modeling_deepseekocr.py @@ -0,0 +1,1037 @@ +from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM +from .configuration_deepseek_v2 import DeepseekV2Config +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from typing import List, Optional, Tuple, Union +from transformers.cache_utils import Cache +import requests +from PIL import Image, ImageOps, ImageDraw, ImageFont +from io import BytesIO +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import os +from .deepencoder import build_sam_vit_b, build_clip_l, MlpProjector +from addict import Dict +from transformers import TextStreamer +from .conversation import get_conv_template +from abc import ABC +import math +import re +from tqdm import tqdm +import numpy as np +import time + + +def load_image(image_path): + + try: + image = Image.open(image_path) + + corrected_image = ImageOps.exif_transpose(image) + + return corrected_image + + except Exception as e: + print(f"error: {e}") + try: + return Image.open(image_path) + except: + return None + + +def re_match(text): + pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' + matches = re.findall(pattern, text, re.DOTALL) + + # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n' + # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL) + + mathes_image = [] + mathes_other = [] + for a_match in matches: + if '<|ref|>image<|/ref|>' in a_match[0]: + mathes_image.append(a_match[0]) + else: + mathes_other.append(a_match[0]) + return matches, mathes_image, mathes_other + + +def extract_coordinates_and_label(ref_text, image_width, image_height): + + try: + label_type = ref_text[1] + cor_list = eval(ref_text[2]) + except Exception as e: + print(e) + return None + + return 
(label_type, cor_list) + + +def draw_bounding_boxes(image, refs, ouput_path): + + image_width, image_height = image.size + + img_draw = image.copy() + draw = ImageDraw.Draw(img_draw) + + overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) + draw2 = ImageDraw.Draw(overlay) + + # try: + # except IOError: + # try: + # font = ImageFont.truetype("DejaVuSans.ttf", 20) + # except IOError: + font = ImageFont.load_default() + + img_idx = 0 + + for i, ref in enumerate(refs): + try: + result = extract_coordinates_and_label(ref, image_width, image_height) + if result: + label_type, points_list = result + + color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) + + color_a = color + (20, ) + for points in points_list: + x1, y1, x2, y2 = points + + x1 = int(x1 / 999 * image_width) + y1 = int(y1 / 999 * image_height) + + x2 = int(x2 / 999 * image_width) + y2 = int(y2 / 999 * image_height) + + if label_type == 'image': + try: + cropped = image.crop((x1, y1, x2, y2)) + cropped.save(f"{ouput_path}/images/{img_idx}.jpg") + except Exception as e: + print(e) + pass + img_idx += 1 + + try: + if label_type == 'title': + draw.rectangle([x1, y1, x2, y2], outline=color, width=4) + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) + else: + draw.rectangle([x1, y1, x2, y2], outline=color, width=2) + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) + text_x = x1 + text_y = max(0, y1 - 15) + + + text_bbox = draw.textbbox((0, 0), label_type, font=font) + text_width = text_bbox[2] - text_bbox[0] + text_height = text_bbox[3] - text_bbox[1] + draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], + fill=(255, 255, 255, 30)) + + draw.text((text_x, text_y), label_type, font=font, fill=color) + except: + pass + except: + continue + img_draw.paste(overlay, (0, 0), overlay) + return img_draw + + +def process_image_with_refs(image, ref_texts, output_path): + + result_image = 
draw_bounding_boxes(image, ref_texts, output_path) + + return result_image + + + + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') + return best_ratio + + +def dynamic_preprocess(image, min_num=2, max_num=9, image_size=640, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + # print(target_ratios) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # print(target_aspect_ratio) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + 
processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images, target_aspect_ratio + + + +def normalize_transform(mean, std): + if mean is None and std is None: + transform = None + elif mean is None and std is not None: + mean = [0.] * len(std) + transform = transforms.Normalize(mean=mean, std=std) + elif mean is not None and std is None: + std = [1.] * len(mean) + transform = transforms.Normalize(mean=mean, std=std) + else: + transform = transforms.Normalize(mean=mean, std=std) + + return transform + + + +def format_messages( + conversations: List[Dict[str, str]], + sft_format: str = "deepseek", + system_prompt: str = "", +): + """ + Applies the SFT template to conversation. + + Args: + conversations (List[Dict]): A List of messages. + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". + + Returns: + sft_prompt (str): The formatted text. + """ + + conv = get_conv_template(sft_format) + conv.set_system_message(system_prompt) + for message in conversations: + conv.append_message(message["role"], message["content"].strip()) + sft_prompt = conv.get_prompt().strip() + + return sft_prompt + + +def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): + t = tokenizer.encode(text, add_special_tokens=False) + bos_id = 0 + eos_id = 1 + if bos: + t = [bos_id] + t + if eos: + t = t + [eos_id] + + return t + +def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: + """ + + Args: + conversations (List[Dict[str, str]]): the conversations with a list of messages. 
An example is : + [ + { + "role": "User", + "content": "\nExtract all information from this image and convert them into markdown format.", + "images": ["./examples/table_datasets.png"] + }, + {"role": "Assistant", "content": ""}, + ] + + Returns: + pil_images (List[PIL.Image.Image]): the list of PIL images. + + """ + + pil_images = [] + + for message in conversations: + if "images" not in message: + continue + + for image_path in message["images"]: + # print('----------------') + # print(image_path) + # print('----------------') + # exit() + + # pil_img = Image.open(image_path) + pil_img = load_image(image_path) + pil_img = pil_img.convert("RGB") + pil_images.append(pil_img) + + return pil_images + + +class BaseTransform(ABC): + + def set_rng(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs) -> torch.Tensor: + pass + + @property + def default_shape(self): + raise NotImplementedError + + +class BasicImageTransform(BaseTransform): + def __init__( + self, + mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), + std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), + normalize: bool = True + ): + self.mean = mean + self.std = std + + transform_pipelines = [ + transforms.ToTensor() + ] + + normalize = normalize_transform(mean, std) if normalize else nn.Identity() + if normalize is not None: + transform_pipelines.append(normalize) + + self.transform = transforms.Compose(transform_pipelines) + + def __call__(self, x): + x = self.transform(x) + return x + +class NoEOSTextStreamer(TextStreamer): + def on_finalized_text(self, text: str, stream_end: bool = False): + + eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) + text = text.replace(eos_text, "\n") + print(text, flush=True, end="") + + +class DeepseekOCRConfig(DeepseekV2Config): + model_type = "DeepseekOCR" + +class DeepseekOCRModel(DeepseekV2Model): + config_class = DeepseekOCRConfig + + def __init__(self, config: DeepseekV2Config): + 
super(DeepseekOCRModel, self).__init__(config) + + self.sam_model = build_sam_vit_b() + self.vision_model = build_clip_l() + # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2) + n_embed = 1280 + self.projector = MlpProjector(Dict(projector_type="linear", input_dim=2048, n_embed=n_embed)) + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) + self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) + + + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + images_seq_mask: Optional[torch.FloatTensor] = None, + images_spatial_crop: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + + + + if inputs_embeds is None: + # inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.get_input_embeddings()(input_ids) + + + + sam_model = getattr(self, 'sam_model', None) + # sam_model = self.sam_model + vision_model = getattr(self, 'vision_model', None) + + + + if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: + + idx = 0 + + # sam_model = torch.jit.script(sam_model) + + # start_time = time.time() + for image, crop_shape in zip(images, images_spatial_crop): + images_in_this_batch = [] + + patches = image[0] + image_ori = image[1] + + with torch.no_grad(): + # with torch.inference_mode(): + + if torch.sum(patches).item() != 0: + # P, C, H, W = patches.shape + crop_flag = 1 + local_features_1 = 
sam_model(patches) + + local_features_2 = vision_model(patches, local_features_1) + # vit_time = time.time() + local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + local_features = self.projector(local_features) + + + global_features_1 = sam_model(image_ori) + global_features_2 = vision_model(image_ori, global_features_1) + global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + global_features = self.projector(global_features) + + print('=====================') + print('BASE: ', global_features.shape) + print('PATCHES: ', local_features.shape) + print('=====================') + + _, hw, n_dim = global_features.shape + h = w = int(hw ** 0.5) + + _2, hw2, n_dim2 = local_features.shape + h2 = w2 = int(hw2 ** 0.5) + + width_crop_num, height_crop_num = crop_shape[0], crop_shape[1] + + global_features = global_features.view(h, w, n_dim) + + global_features = torch.cat( + [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 + ) + + global_features = global_features.view(-1, n_dim) + + + local_features = local_features.view(height_crop_num, width_crop_num, h2, w2, n_dim2).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2) + local_features = torch.cat( + [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1 + ) + local_features = local_features.view(-1, n_dim2) + + global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) + + # end_time = time.time() + + # print('sam: ', sam_time - start_time) + # print('vit: ', vit_time - sam_time) + # print('all: ', end_time - start_time) + + # exit() + + else: + global_features_1 = sam_model(image_ori) + global_features_2 = vision_model(image_ori, global_features_1) + global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) + 
global_features = self.projector(global_features) + print('=====================') + print('BASE: ', global_features.shape) + print('NO PATCHES') + print('=====================') + _, hw, n_dim = global_features.shape + h = w = int(hw ** 0.5) + + + global_features = global_features.view(h, w, n_dim) + + global_features = torch.cat( + [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 + ) + + global_features = global_features.view(-1, n_dim) + + global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) + + images_in_this_batch.append(global_local_features) + + + # print(inputs_embeds.shape) + + if images_in_this_batch: + images_in_this_batch = torch.cat(images_in_this_batch, dim=0) + # exit() + + inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).npu(), images_in_this_batch) + + idx += 1 + + + return super(DeepseekOCRModel, self).forward( + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + +class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM): + + config_class = DeepseekOCRConfig + # supports_gradient_checkpointing = True + + def __init__(self, config): + super(DeepseekV2ForCausalLM, self).__init__(config) + self.model = DeepseekOCRModel(config) + + self.vocab_size = config.vocab_size + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: 
Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + images_seq_mask: Optional[torch.FloatTensor] = None, + images_spatial_crop: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + + outputs = self.model( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + images=images, + images_seq_mask = images_seq_mask, + images_spatial_crop = images_spatial_crop, + return_dict=return_dict + + ) + + + + # print(transformer_outputs) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + # logits + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return 
CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if self.generation_config.cache_implementation == "static": + # # generation with static cache + # cache_position = kwargs.get("cache_position", None) + # if cache_position is None: + # past_length = 0 + # else: + # past_length = cache_position[-1] + 1 + # input_ids = input_ids[:, past_length:] + # position_ids = position_ids[:, past_length:] + + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "images": kwargs.get("images", None), + "images_seq_mask": kwargs.get("images_seq_mask", None), + "images_spatial_crop": kwargs.get("images_spatial_crop", None), + } + ) + return model_inputs + + + def disable_torch_init(self): + """ + Disable the redundant torch default initialization to accelerate model creation. 
+ """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + + + def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): + self.disable_torch_init() + + os.makedirs(output_path, exist_ok=True) + os.makedirs(f'{output_path}/images', exist_ok=True) + + if prompt and image_file: + conversation = [ + { + "role": "<|User|>", + # "content": "\n<|grounding|>Given the layout of the image. ", + "content": f'{prompt}', + # "content": "君不见黄河之水天上来的下一句是什么?", + # "content": "\nFree OCR. ", + # "content": "\nParse the figure. ", + # "content": "\nExtract the text in the image. ", + "images": [f'{image_file}'], + }, + {"role": "<|Assistant|>", "content": ""}, + ] + + elif prompt: + conversation = [ + { + "role": "<|User|>", + # "content": "\n<|grounding|>Given the layout of the image. ", + "content": f'{prompt}', + # "content": "君不见黄河之水天上来的下一句是什么?", + # "content": "\nFree OCR. ", + # "content": "\nParse the figure. ", + # "content": "\nExtract the text in the image. ", + # "images": [f'{image_file}'], + }, + {"role": "<|Assistant|>", "content": ""}, + ] + else: + assert False, f'prompt is none!' 
+
+        # ------------------------------------------------------------------
+        # Prompt construction: flatten the conversation to a plain prompt and
+        # interleave text token ids with image-placeholder token ids.
+        # NOTE(review): this hunk is the body of a larger method (its `def`
+        # line precedes this fragment). Several string literals below are
+        # empty (e.g. image_token = '') — upstream DeepSeek-OCR uses the
+        # '<image>' tag there, so the tags were likely stripped by tooling
+        # when this patch was produced. Confirm against the original
+        # modeling_deepseekocr.py before relying on this code.
+        # ------------------------------------------------------------------
+        prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
+
+        # Vision-tokenizer geometry: ViT patch size and query downsampling.
+        patch_size = 16
+        downsample_ratio = 4
+        images = load_pil_images(conversation)
+
+        valid_img_tokens = 0   # running count of content-bearing vision tokens (for compression stats)
+        ratio = 1
+
+        image_draw = images[0].copy()  # untouched copy used later for drawing result boxes
+
+        w,h = image_draw.size
+        # print(w, h)
+        # Aspect-ratio discount: 1.0 for a square image, smaller as the image
+        # gets more elongated (square-padding pixels carry no content).
+        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
+
+
+        image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
+        images_seq_mask = []
+
+        # NOTE(review): empty separator — str.split('') raises ValueError at
+        # runtime. Upstream uses image_token = '<image>'; the literal was
+        # presumably lost in patch generation. Verify.
+        image_token = ''
+        image_token_id = 128815  # vocab id of the image-placeholder token
+        text_splits = prompt.split(image_token)
+
+        images_list, images_crop_list, images_seq_mask = [], [], []
+        tokenized_str = []
+        images_spatial_crop = []   # per-image [width_crop_num, height_crop_num]
+        for text_sep, image in zip(text_splits, images):
+
+            # Text chunk preceding this image placeholder.
+            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
+            tokenized_str += tokenized_sep
+            images_seq_mask += [False] * len(tokenized_sep)
+
+            if crop_mode:
+
+                # Small images are not tiled; larger ones are split into local crops.
+                if image.size[0] <= 640 and image.size[1] <= 640:
+                    crop_ratio = [1, 1]
+
+                else:
+                    if crop_mode:  # NOTE(review): always true inside this branch — dead else below
+                        # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
+                        images_crop_raw, crop_ratio = dynamic_preprocess(image)
+                    else:
+                        # best_width, best_height = self.image_size, self.image_size
+                        crop_ratio = [1, 1]
+
+                """process the global view"""
+                # image = image.resize((base_size, base_size))
+                # Pad the whole image onto a square base_size canvas, filling
+                # with the normalization mean colour.
+                global_view = ImageOps.pad(image, (base_size, base_size),
+                                           color=tuple(int(x * 255) for x in image_transform.mean))
+
+                # Credit the global view's approximate useful token count.
+                if base_size == 1024:
+                    valid_img_tokens += int(256 * ratio)
+                elif base_size == 1280:
+                    valid_img_tokens += int(400 * ratio)
+                # elif base_size == 640:
+                #     valid_img_tokens += int(100 * ratio)
+
+
+
+
+
+                images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
+
+                width_crop_num, height_crop_num = crop_ratio
+
+                images_spatial_crop.append([width_crop_num, height_crop_num])
+
+
+                if width_crop_num > 1 or height_crop_num > 1:
+                    """process the local views"""
+                    for i in range(len(images_crop_raw)):
+                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
+
+                    # Each 640x640 local crop contributes ~100 valid tokens.
+                    if image_size == 640:
+                        valid_img_tokens += len(images_crop_list) * 100
+
+                # Placeholder tokens per side after patchify + downsampling.
+                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
+
+
+
+                """add image tokens"""
+
+
+
+                # Global view: num_queries_base rows of num_queries_base tokens,
+                # each row followed by one extra placeholder (row separator),
+                # plus one trailing separator token.
+                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
+                tokenized_image += [image_token_id]
+                # Local crops: same row layout scaled by the crop grid.
+                if width_crop_num > 1 or height_crop_num > 1:
+                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
+                        num_queries * height_crop_num)
+                tokenized_str += tokenized_image
+                images_seq_mask += [True] * len(tokenized_image)
+                # num_image_tokens.append(len(tokenized_image))
+
+            else:
+                # best_width, best_height = self.image_size, self.image_size
+                # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
+
+                """process the global view"""
+                if image_size <= 640:
+                    print('directly resize')
+                    image = image.resize((image_size, image_size))
+                # else:
+                global_view = ImageOps.pad(image, (image_size, image_size),
+                                           color=tuple(int(x * 255) for x in image_transform.mean))
+                images_list.append(image_transform(global_view).to(torch.bfloat16))
+
+                # NOTE(review): this branch pads to `image_size` but credits
+                # tokens keyed on `base_size` — confirm that is intentional.
+                if base_size == 1024:
+                    valid_img_tokens += int(256 * ratio)
+                elif base_size == 1280:
+                    valid_img_tokens += int(400 * ratio)
+                elif base_size == 640:
+                    valid_img_tokens += int(100 * 1)
+                elif base_size == 512:
+                    valid_img_tokens += int(64 * 1)
+
+                # No tiling in this mode: a single 1x1 "crop" grid.
+                width_crop_num, height_crop_num = 1, 1
+
+                images_spatial_crop.append([width_crop_num, height_crop_num])
+
+
+                """add image tokens"""
+                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
+
+                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
+                tokenized_image += [image_token_id]
+                # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
+                #     num_queries * height_crop_num)
+                tokenized_str += tokenized_image
+                images_seq_mask += [True] * len(tokenized_image)
+                # num_image_tokens.append(len(tokenized_image))
+
+
+        """process the last text split"""
+        # With N images the zip above consumes the first N splits; the final
+        # split (text after the last image placeholder) is appended here.
+        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
+        tokenized_str += tokenized_sep
+        images_seq_mask += [False] * len(tokenized_sep)
+
+        """add the bos tokens"""
+        bos_id = 0
+        tokenized_str = [bos_id] + tokenized_str
+        images_seq_mask = [False] + images_seq_mask
+
+
+
+        input_ids = torch.LongTensor(tokenized_str)
+
+
+
+
+        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+
+        # Batch image tensors; zero placeholders keep the model inputs
+        # well-formed when no image (or no crops) are present.
+        if len(images_list) == 0:
+            images_ori = torch.zeros((1, 3, image_size, image_size))
+            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
+            images_crop = torch.zeros((1, 3, base_size, base_size))
+
+        else:
+            images_ori = torch.stack(images_list, dim=0)
+            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+            if images_crop_list:
+                images_crop = torch.stack(images_crop_list, dim=0)
+            else:
+                images_crop = torch.zeros((1, 3, base_size, base_size))
+
+
+
+        if not eval_mode:
+            # Interactive path: stream decoded text to stdout as it is generated.
+            streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                with torch.no_grad():
+                    output_ids = self.generate(
+                        input_ids.unsqueeze(0).npu(),  # .npu(): Ascend NPU placement (torch_npu)
+                        images=[(images_crop.npu(), images_ori.npu())],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).npu(),
+                        # NOTE(review): images_spatial_crop is NOT moved to the
+                        # NPU, unlike the other tensors — confirm intentional.
+                        images_spatial_crop = images_spatial_crop,
+                        # do_sample=False,
+                        # num_beams = 1,
+                        temperature=0.0,
+                        eos_token_id=tokenizer.eos_token_id,
+                        streamer=streamer,
+                        max_new_tokens=8192,
+                        no_repeat_ngram_size = 20,
+                        use_cache = True
+                        )
+
+        else:
+            # Batch/eval path: no streamer, stricter n-gram repetition block.
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                with torch.no_grad():
+                    output_ids = self.generate(
+                        input_ids.unsqueeze(0).npu(),
+                        images=[(images_crop.npu(), images_ori.npu())],
+                        images_seq_mask = images_seq_mask.unsqueeze(0).npu(),
+                        images_spatial_crop = images_spatial_crop,
+                        # do_sample=False,
+                        # num_beams = 1,
+                        temperature=0.0,
+                        eos_token_id=tokenizer.eos_token_id,
+                        max_new_tokens=8192,
+                        no_repeat_ngram_size = 35,
+                        use_cache = True
+                        )
+
+
+        # NOTE(review): the `'' in ...` membership tests below are trivially
+        # true for any string; upstream tests for the '<image>' tag — the
+        # literal was likely stripped by tooling. Confirm.
+        if '' in conversation[0]['content'] and eval_mode:
+            # Decode only the newly generated tokens (skip the prompt length).
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).npu().shape[1]:])
+            stop_str = '<|end▁of▁sentence|>'
+            if outputs.endswith(stop_str):
+                outputs = outputs[:-len(stop_str)]
+            # re_match
+            outputs = outputs.strip()
+
+            return outputs
+
+        if '' in conversation[0]['content'] and test_compress:
+            # Report text-token vs vision-token compression statistics.
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).npu().shape[1]:])
+            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
+            print('='*50)
+            print('image size: ', (w, h))
+            print('valid image tokens: ', int(valid_img_tokens))
+            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
+            # NOTE(review): raises ZeroDivisionError when valid_img_tokens == 0.
+            print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
+            print('='*50)
+
+
+        if '' in conversation[0]['content'] and save_results:
+            outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).npu().shape[1]:])
+            stop_str = '<|end▁of▁sentence|>'
+
+            print('='*15 + 'save results:' + '='*15)
+
+            # # # # conv.messages[-1][-1] = outputs
+            if outputs.endswith(stop_str):
+                outputs = outputs[:-len(stop_str)]
+            outputs = outputs.strip()
+
+            # Extract grounding refs / embedded-image refs from the raw output
+            # and draw the reference boxes onto a copy of the input image.
+            matches_ref, matches_images, mathes_other = re_match(outputs)
+            # print(matches_ref)
+            result = process_image_with_refs(image_draw, matches_ref, output_path)
+
+
+            # Rewrite embedded-image refs as markdown links into images/.
+            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
+                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')
+
+            # Strip remaining markup and normalize a couple of LaTeX macros.
+            for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
+                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
+
+
+            # if 'structural formula' in conversation[0]['content']:
+            #     outputs = '' + outputs + ''
+            with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
+                afile.write(outputs)
+
+            # Geometry output: the model emitted a dict-literal of line segments
+            # and labelled endpoints; render it to geo.jpg.
+            if 'line_type' in outputs:
+                import matplotlib.pyplot as plt
+                # NOTE(review): eval() on model output is unsafe for untrusted
+                # input and is re-parsed three times; ast.literal_eval, parsed
+                # once, would be safer — flagged, not changed here.
+                lines = eval(outputs)['Line']['line']
+
+                line_type = eval(outputs)['Line']['line_type']
+                # print(lines)
+
+                endpoints = eval(outputs)['Line']['line_endpoint']
+
+                fig, ax = plt.subplots(figsize=(3,3), dpi=200)
+                ax.set_xlim(-15, 15)
+                ax.set_ylim(-15, 15)
+
+                for idx, line in enumerate(lines):
+                    try:
+                        # Each entry looks like "(x0, y0) -- (x1, y1)".
+                        p0 = eval(line.split(' -- ')[0])
+                        p1 = eval(line.split(' -- ')[-1])
+
+                        # NOTE(review): both branches draw an identical solid
+                        # line; the '--' case presumably intended
+                        # linestyle='--' (dashed) — confirm.
+                        if line_type[idx] == '--':
+                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
+                        else:
+                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
+
+                        ax.scatter(p0[0], p0[1], s=5, color = 'k')
+                        ax.scatter(p1[0], p1[1], s=5, color = 'k')
+                    except:
+                        # Best-effort rendering: malformed entries are skipped.
+                        pass
+
+                # Annotate labelled endpoints, e.g. "A: (3, 4)".
+                for endpoint in endpoints:
+
+                    label = endpoint.split(': ')[0]
+                    (x, y) = eval(endpoint.split(': ')[1])
+                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
+                                fontsize=5, fontweight='light')
+
+
+                plt.savefig(f'{output_path}/geo.jpg')
+                plt.close()
+
+            result.save(f"{output_path}/result_with_boxes.jpg")