From 36a2680d280422afc6e78a359eddab796579b9af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=89=E6=B4=8B=E4=B8=89=E6=B4=8B?= <1258009915@qq.com> Date: Sat, 22 Jun 2024 22:08:08 +0800 Subject: [PATCH] Update model config --- src/models/ocr_model/model/TexTeller.py | 10 ++++---- src/models/ocr_model/model/config.json | 32 +++++++++++++++++-------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/models/ocr_model/model/TexTeller.py b/src/models/ocr_model/model/TexTeller.py index 381bcf8..0a811f2 100644 --- a/src/models/ocr_model/model/TexTeller.py +++ b/src/models/ocr_model/model/TexTeller.py @@ -18,11 +18,11 @@ from transformers import ( class TexTeller(VisionEncoderDecoderModel): REPO_NAME = 'OleehyO/TexTeller' def __init__(self): - config = VisionEncoderDecoderConfig.from_pretrained('/home/lhy/code/TexTeller/src/models/ocr_model/model/trocr-small') - config.encoder.image_size = FIXED_IMG_SIZE - config.encoder.num_channels = IMG_CHANNELS - config.decoder.vocab_size=VOCAB_SIZE - config.decoder.max_position_embeddings=MAX_TOKEN_SIZE + config = VisionEncoderDecoderConfig.from_pretrained(Path(__file__).resolve().parent / "config.json") + config.encoder.image_size = FIXED_IMG_SIZE + config.encoder.num_channels = IMG_CHANNELS + config.decoder.vocab_size = VOCAB_SIZE + config.decoder.max_position_embeddings = MAX_TOKEN_SIZE super().__init__(config=config) diff --git a/src/models/ocr_model/model/config.json b/src/models/ocr_model/model/config.json index f8ab627..45365ba 100644 --- a/src/models/ocr_model/model/config.json +++ b/src/models/ocr_model/model/config.json @@ -1,4 +1,5 @@ { + "_name_or_path": "OleehyO/TexTeller", "architectures": [ "VisionEncoderDecoderModel" ], @@ -10,9 +11,11 @@ "architectures": null, "attention_dropout": 0.0, "bad_words_ids": null, + "begin_suppress_tokens": null, "bos_token_id": 0, "chunk_size_feed_forward": 0, "classifier_dropout": 0.0, + "cross_attention_hidden_size": 768, "d_model": 1024, "decoder_attention_heads": 16, "decoder_ffn_dim": 4096, @@ -23,9 +26,9 @@ "do_sample": false, "dropout": 0.1, "early_stopping": false, - "cross_attention_hidden_size": 768, "encoder_no_repeat_ngram_size": 0, "eos_token_id": 2, + "exponential_decay_length_penalty": null, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": null, @@ -40,9 +43,10 @@ "LABEL_0": 0, "LABEL_1": 1 }, + "layernorm_embedding": true, "length_penalty": 1.0, "max_length": 20, - "max_position_embeddings": 512, + "max_position_embeddings": 1024, "min_length": 0, "model_type": "trocr", "no_repeat_ngram_size": 0, @@ -62,8 +66,10 @@ "return_dict_in_generate": false, "scale_embedding": false, "sep_token_id": null, + "suppress_tokens": null, "task_specific_params": null, "temperature": 1.0, + "tf_legacy_loss": false, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, @@ -71,10 +77,11 @@ "top_p": 1.0, "torch_dtype": null, "torchscript": false, - "transformers_version": "4.12.0.dev0", + "typical_p": 1.0, "use_bfloat16": false, "use_cache": false, - "vocab_size": 50265 + "use_learned_position_embeddings": true, + "vocab_size": 15000 }, "encoder": { "_name_or_path": "", @@ -82,15 +89,18 @@ "architectures": null, "attention_probs_dropout_prob": 0.0, "bad_words_ids": null, + "begin_suppress_tokens": null, "bos_token_id": null, "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, "decoder_start_token_id": null, "diversity_penalty": 0.0, "do_sample": false, "early_stopping": false, - "cross_attention_hidden_size": null, "encoder_no_repeat_ngram_size": 0, + "encoder_stride": 16, "eos_token_id": null, + "exponential_decay_length_penalty": null, "finetuning_task": null, "forced_bos_token_id": null, "forced_eos_token_id": null, @@ -101,7 +111,7 @@ "0": "LABEL_0", "1": "LABEL_1" }, - "image_size": 384, + "image_size": 448, "initializer_range": 0.02, "intermediate_size": 3072, "is_decoder": false, @@ -119,7 +129,7 @@ "num_attention_heads": 12, "num_beam_groups": 1, "num_beams": 1, - "num_channels": 3, + "num_channels": 1, "num_hidden_layers": 12, "num_return_sequences": 1, "output_attentions": false, @@ -136,8 +146,10 @@ "return_dict": true, "return_dict_in_generate": false, "sep_token_id": null, + "suppress_tokens": null, "task_specific_params": null, "temperature": 1.0, + "tf_legacy_loss": false, "tie_encoder_decoder": false, "tie_word_embeddings": true, "tokenizer_class": null, @@ -145,12 +157,12 @@ "top_p": 1.0, "torch_dtype": null, "torchscript": false, - "transformers_version": "4.12.0.dev0", + "typical_p": 1.0, "use_bfloat16": false }, "is_encoder_decoder": true, "model_type": "vision-encoder-decoder", "tie_word_embeddings": false, - "torch_dtype": "float32", - "transformers_version": null + "transformers_version": "4.41.2", + "use_cache": true }