1) 修复了to_katex.py的bug; 2)把Box.py中的转化结果写在logs
This commit is contained in:
@@ -1,5 +1,8 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class Point:
|
class Point:
|
||||||
@@ -63,6 +66,9 @@ class Bbox:
|
|||||||
|
|
||||||
|
|
||||||
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
|
def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"):
|
||||||
|
curr_work_dir = Path(os.getcwd())
|
||||||
|
log_dir = curr_work_dir / "logs"
|
||||||
|
log_dir.mkdir(exist_ok=True)
|
||||||
drawer = ImageDraw.Draw(img)
|
drawer = ImageDraw.Draw(img)
|
||||||
for bbox in bboxes:
|
for bbox in bboxes:
|
||||||
# Calculate the coordinates for the rectangle to be drawn
|
# Calculate the coordinates for the rectangle to be drawn
|
||||||
@@ -82,4 +88,4 @@ def draw_bboxes(img: Image.Image, bboxes: List[Bbox], name="annotated_image.png"
|
|||||||
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
|
drawer.text((left, bottom - 10), bbox.content[:10], fill="red")
|
||||||
|
|
||||||
# Save the image with drawn rectangles
|
# Save the image with drawn rectangles
|
||||||
img.save(name)
|
img.save(log_dir / name)
|
||||||
@@ -38,7 +38,7 @@ def train(model, tokenizer, train_dataset, eval_dataset, collate_fn_with_tokeniz
|
|||||||
)
|
)
|
||||||
|
|
||||||
# trainer.train(resume_from_checkpoint=None)
|
# trainer.train(resume_from_checkpoint=None)
|
||||||
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
|
trainer.train(resume_from_checkpoint='/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-788000')
|
||||||
|
|
||||||
|
|
||||||
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
def evaluate(model, tokenizer, eval_dataset, collate_fn):
|
||||||
@@ -96,7 +96,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
collate_fn_with_tokenizer = partial(collate_fn, tokenizer=tokenizer)
|
||||||
# model = TexTeller()
|
# model = TexTeller()
|
||||||
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-644000')
|
model = TexTeller.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-788000')
|
||||||
|
|
||||||
# ================= debug =======================
|
# ================= debug =======================
|
||||||
# foo = train_dataset[:50]
|
# foo = train_dataset[:50]
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ def change(input_str, old_inst, new_inst, old_surr_l, old_surr_r, new_surr_l, ne
|
|||||||
i = start + 1
|
i = start + 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
result += input_str[i:start]
|
||||||
i = start
|
i = start
|
||||||
|
|
||||||
if old_inst != new_inst and old_inst in result:
|
if old_inst != new_inst and old_inst in result:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
set -exu
|
set -exu
|
||||||
|
|
||||||
export CHECKPOINT_DIR="default"
|
export CHECKPOINT_DIR="/home/lhy/code/TexTeller/src/models/ocr_model/train/train_result/TexTellerv3/checkpoint-788000"
|
||||||
export TOKENIZER_DIR="default"
|
export TOKENIZER_DIR="default"
|
||||||
|
|
||||||
streamlit run web.py
|
streamlit run web.py
|
||||||
|
|||||||
Reference in New Issue
Block a user