# Source code for dawsonia.prepare
"""
NOTE: This used to be the ``change_format.py`` script.
"""
import random
import shutil
from pathlib import Path
from warnings import warn
import typer
from .config import config_cli_names, config_kwargs
app = typer.Typer()
[docs]
def new_label(label: str) -> str:
    """Encode ``label`` as a hyphen-separated token string (Washington format).

    Every character of ``label`` becomes one token and the tokens are joined
    with ``"-"``. Three characters are replaced by a token name, all other
    characters are kept unchanged:

    * ``"-"`` becomes ``"s_mi"``
    * ``"."`` becomes ``"s_pt"``
    * ``" "`` becomes ``"n"``

    Parameters
    ----------
    label:
        The plain-text label to encode.

    Returns
    -------
    str
        The encoded label; an empty string when ``label`` is empty.
    """
    special_tokens = {"-": "s_mi", ".": "s_pt", " ": "n"}
    # Characters without a special token map to themselves.
    return "-".join(special_tokens.get(char, char) for char in label)
[docs]
def convert(name: str, word_end_suffix: bool) -> tuple[str, str]:
    """Separate the file ID from the label and build the new name and label.

    The stem of ``name`` (directory and last extension removed) is split at
    its first underscore: the part before it is the file ID, the part after
    it is the raw label.

    Parameters
    ----------
    name:
        File name or path of the form ``<id>_<label>.<ext>``.
    word_end_suffix:
        When true, ``"-n"`` is appended to the encoded label.

    Returns
    -------
    tuple[str, str]
        ``("<id>-0.png", encoded_label)``. The label is encoded with
        ``new_label``; when nothing follows the underscore it is ``"n"``.

    Raises
    ------
    IOError
        If the stem contains no underscore.
    """
    stem = Path(name).stem
    try:
        split_at = stem.index("_")
    except ValueError as err:
        raise IOError(f"Bad filename {stem}") from err

    file_id = stem[:split_at]
    raw_label = stem[split_at + 1 :]

    # Nothing after the underscore: fall back to the bare "n" token.
    encoded = new_label(raw_label) if raw_label else "n"
    if word_end_suffix:
        encoded = f"{encoded}-n"
    return f"{file_id}-0.png", encoded
[docs]
@app.callback(invoke_without_command=True)
def command(
    n_train: int,
    n_val: int,
    n_test: int,
    label_path: Path = Path(
        "/local_disk/", "data", "ai-for-obs", "interim", "label_old"
    ),
    model_path: Path = Path(
        "/local_disk/",
        "data",
        "ai-for-obs",
        "interim",
        "model_tmp",
    ),
    word_end_suffix: bool = False,
    source: str = "washington",
    config: Path = typer.Option("dawsonia.toml", *config_cli_names, **config_kwargs),
):  # pylint: disable=unused-argument, too-many-arguments, too-many-locals
    """Creates new train, validation, test and ground truth text files in "washington"
    format for the HTR network and copies in images as input data for the model.

    All output is written below ``model_path / "raw" / source``. The image
    directory ``data/line_images_normalized`` is deleted and recreated on
    every run; ``ground_truth/transcription.txt`` and the three files in
    ``sets/cv1`` are overwritten.

    Parameters\n
    ----------\n
    n_train: int
        Number of images in training set
    n_val: int
        Number of images in validation set
    n_test: int
        Number of images in test set. NOTE: If n_test == -1, all the files from
        label_path would be used for testing
    label_path: Path
        Path to label directory (where the pictures are located). Only
        ``.png`` files are used and a file named ``temp.png`` is ignored.
    model_path: Path
        Root directory under which the ``raw/<source>`` tree is created.
    word_end_suffix: bool
        Passed on to ``convert``; when set, every label gets a trailing "-n".
    source: str
        Name of the sub-directory of ``model_path / "raw"`` to write into.
    config: Path
        Not used in this function body; presumably consumed through the
        option settings in ``config_kwargs`` (to be confirmed there).
    """
    # Path to directory where pictures with new labels will be located. Delete
    # old files each time new sets are created.
    model_source = model_path / "raw" / source
    model_input_path = model_source / "data" / "line_images_normalized"
    # Paths to ground truths and sets.
    ground_truth_path = model_source / "ground_truth"
    sets_path = model_source / "sets" / "cv1"

    # ``exists()`` follows symlinks and is False for a dangling link, so the
    # link itself has to be tested as well or ``mkdir`` below would fail.
    if model_input_path.is_symlink() or model_input_path.exists():
        print("Removing symlink", model_input_path, "and recreating directory.")
        if model_input_path.is_symlink():
            model_input_path.unlink()
        else:
            shutil.rmtree(model_input_path)
    model_input_path.mkdir(parents=True)
    ground_truth_path.mkdir(exist_ok=True)
    sets_path.mkdir(parents=True, exist_ok=True)

    # List all usable files in the label directory. ``temp.png`` is dropped
    # here, before splitting, so that it cannot use up a slot in one of the
    # sets.
    images = [
        i
        for i in label_path.iterdir()
        if i.suffix == ".png" and i.name != "temp.png"
    ]

    # Create the sets randomly with the specified size. Avoid overlap.
    random.shuffle(images)
    n_fit = n_train + n_val
    train_files = images[:n_train]
    validation_files = images[n_train:n_fit]
    # n_test == -1 selects every image for testing, including those that are
    # also part of the training and validation sets.
    test_files = images if n_test == -1 else images[n_fit : n_fit + n_test]

    # Source files already copied and transcribed. With n_test == -1 a file
    # belongs to two sets but must be copied and transcribed only once.
    processed: set[Path] = set()

    # Write the text files for ground truth, train, validation and test sets.
    with (
        (ground_truth_path / "transcription.txt").open("w") as transcription,
        (sets_path / "train.txt").open("w") as train,
        (sets_path / "test.txt").open("w") as test,
        (sets_path / "valid.txt").open("w") as validation,
    ):
        for files, set_file in (
            (train_files, train),
            (validation_files, validation),
            (test_files, test),
        ):
            for path in files:
                # Get the new name and label of the file.
                new_name, label = convert(path.name, word_end_suffix)
                new_stem = Path(new_name).stem
                if path not in processed:
                    processed.add(path)
                    # Copy the picture to the directory with the new names.
                    shutil.copyfile(path, model_input_path / new_name)
                    # Write the ground truth.
                    transcription.write(new_stem + " " + label + "\n")
                # Register the image in the set it was assigned to.
                set_file.write(new_stem + "\n")

    # Report images that ended up in no set at all, in one summary warning.
    n_excluded = len(images) - len(processed)
    if n_excluded:
        warn(f"{n_excluded} images from {label_path} were excluded from all sets")

    print("New train, validation and test sets have been created.")
    print(f"{model_input_path = } contains {len(processed)} images")
    print(f"{ground_truth_path = }")
    print(f"{sets_path = }")