Source code for dawsonia.prepare

"""

NOTE: This used to be the ``change_format.py`` script.
"""

import random
import shutil
from pathlib import Path
from warnings import warn

import typer

from .config import config_cli_names, config_kwargs

# Typer application for this module. ``command`` is registered on it as the
# callback with ``invoke_without_command=True``, so running the app executes
# ``command`` directly without requiring a subcommand.
app = typer.Typer()


def new_label(label: str) -> str:
    """Encode a label in the hyphen-separated format of the Washington dataset.

    Every character of *label* becomes one token, and the tokens are joined
    with ``-``. A few characters are replaced by named tokens: ``-`` becomes
    ``s_mi``, ``.`` becomes ``s_pt`` and a space becomes ``n``.

    Parameters
    ----------
    label : str
        Plain-text label, e.g. ``"12.5"``.

    Returns
    -------
    str
        The encoded label, e.g. ``"1-2-s_pt-5"``. An empty label yields ``""``.
    """
    special = {"-": "s_mi", ".": "s_pt", " ": "n"}
    # A single join avoids quadratic string concatenation in a loop. It also
    # avoids having to trim a trailing separator afterwards.
    return "-".join(special.get(char, char) for char in label)
def convert(name: str, word_end_suffix: bool) -> tuple[str, str]:
    """Separate the file ID from the label text and encode the label.

    The filename stem is expected to look like ``<id>_<label>``. Everything
    before the first underscore is the ID, and everything after it is the
    label text.

    Parameters
    ----------
    name : str
        File name or path; any directory part and the extension are ignored.
    word_end_suffix : bool
        If true, append an extra ``n`` token (``"-n"``) to the encoded label.

    Returns
    -------
    tuple[str, str]
        ``("<id>-0.png", encoded_label)``. An empty label text is encoded
        as ``"n"``.

    Raises
    ------
    IOError
        If the stem contains no underscore.
    """
    stem = Path(name).stem
    try:
        cut = stem.index("_")
    except ValueError as exc:
        raise IOError(f"Bad filename {stem}") from exc

    file_id, raw_label = stem[:cut], stem[cut + 1 :]
    encoded = new_label(raw_label) if raw_label else "n"
    if word_end_suffix:
        encoded = f"{encoded}-n"
    return f"{file_id}-0.png", encoded
def _reset_directory(directory: Path) -> None:
    """Delete *directory* (or a symlink at its path) and recreate it empty."""
    # Check is_symlink() first. exists() follows links and returns False for
    # a dangling symlink, which would otherwise make mkdir() fail below.
    if directory.is_symlink():
        print("Removing symlink", directory, "and recreating directory.")
        directory.unlink()
    elif directory.exists():
        print("Removing", directory, "and recreating directory.")
        shutil.rmtree(directory)
    directory.mkdir(parents=True)


def _split_images(
    images: list[Path], n_train: int, n_val: int, n_test: int
) -> tuple[list[Path], list[Path], list[Path]]:
    """Slice *images* into consecutive train, validation and test lists.

    Train and validation never overlap. With ``n_test == -1`` the test list
    contains every image, including the train and validation ones.
    """
    train_files = images[:n_train]
    validation_files = images[n_train : n_train + n_val]
    if n_test == -1:
        test_files = list(images)
    else:
        test_files = images[n_train + n_val : n_train + n_val + n_test]
    return train_files, validation_files, test_files


@app.callback(invoke_without_command=True)
def command(
    n_train: int,
    n_val: int,
    n_test: int,
    label_path: Path = Path(
        "/local_disk/", "data", "ai-for-obs", "interim", "label_old"
    ),
    model_path: Path = Path(
        "/local_disk/",
        "data",
        "ai-for-obs",
        "interim",
        "model_tmp",
    ),
    word_end_suffix: bool = False,
    source: str = "washington",
    config: Path = typer.Option("dawsonia.toml", *config_cli_names, **config_kwargs),
):  # pylint: disable=unused-argument, too-many-arguments, too-many-locals
    """Creates new train, validation, test and ground truth text files in
    "washington" format for the HTR network and copies in image as input data
    for the model.

    Parameters\n
    ----------\n
    n_train: int
        Number of images in training set
    n_val: int
        Number of images in validation set
    n_test: int
        Number of images in test set. NOTE: If n_test == -1, all the files
        from label_path would be used for testing (including the training
        and validation images)
    label_path: Path
        Path to label directory (where the pictures are located).
    model_path: Path
        Root output directory; files are written below model_path/raw/<source>.
    word_end_suffix: bool
        Append an extra "n" token to every encoded label.
    source: str
        Dataset name, used as the output subdirectory.
    config: Path
        Configuration file. It is not read in this function body; presumably
        it is handled through ``config_kwargs``.
    """
    # Directory for the renamed pictures. It is wiped each time new sets are
    # created, so that no stale images from a previous run remain.
    model_source = model_path / "raw" / source
    model_input_path = model_source / "data" / "line_images_normalized"
    # Paths to ground truths and sets.
    ground_truth_path = model_source / "ground_truth"
    sets_path = model_source / "sets" / "cv1"

    _reset_directory(model_input_path)
    ground_truth_path.mkdir(exist_ok=True)
    sets_path.mkdir(parents=True, exist_ok=True)

    # Exclude temp.png up front so it does not occupy a slot in any set.
    # Sorting makes the shuffle reproducible when the caller seeds `random`.
    images = sorted(
        path
        for path in label_path.iterdir()
        if path.suffix == ".png" and path.name != "temp.png"
    )
    n_requested = n_train + n_val + (0 if n_test == -1 else n_test)
    if len(images) < n_requested:
        warn(
            f"Requested {n_requested} images but only {len(images)} were "
            f"found in {label_path}; some sets will be smaller."
        )

    # Create the sets randomly with the specified sizes.
    random.shuffle(images)
    train_files, validation_files, test_files = _split_images(
        images, n_train, n_val, n_test
    )

    # Each image is copied and transcribed once, even when it belongs to
    # several sets (n_test == -1). Set membership comes from the set lists
    # themselves, not from requested sizes, so short sets are classified
    # correctly.
    copied: set[Path] = set()
    with (
        (ground_truth_path / "transcription.txt").open("w") as transcription,
        (sets_path / "train.txt").open("w") as train,
        (sets_path / "test.txt").open("w") as test,
        (sets_path / "valid.txt").open("w") as validation,
    ):
        for set_file, files in (
            (train, train_files),
            (validation, validation_files),
            (test, test_files),
        ):
            for path in files:
                new_name, label = convert(path.name, word_end_suffix)
                new_stem = Path(new_name).stem
                if path not in copied:
                    shutil.copyfile(path, model_input_path / new_name)
                    transcription.write(new_stem + " " + label + "\n")
                    copied.add(path)
                set_file.write(new_stem + "\n")

    print("New train, validation and test sets have been created.")
    print(f"{model_input_path = } contains {len(copied)} images")
    print(f"{ground_truth_path = }")
    print(f"{sets_path = }")