Classification of MNIST dataset

This example shows how you can use MapReader on non-patchified datasets, i.e. standalone images that are classified whole (at "parent" level) rather than being sliced into patches.

Load

https://mapreader.readthedocs.io/en/latest/User-guide/Load.html

Load images

[1]:
from mapreader import loader

path2images = "./small_mnist/*.png"
my_files = loader(path2images)
100%|███████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2711.21it/s]
[2]:
# len() shows the total number of images currently loaded
print(f"Number of images: {len(my_files)}")
Number of images: 200
[3]:
print(my_files)
#images: 200

#parents: 200
49081.png
51816.png
39566.png
24251.png
20989.png
29013.png
10692.png
58832.png
30556.png
10686.png
9294.png
...

#patches: 0

[4]:
my_files.show_sample(num_samples=3, tree_level="parent")
../_images/Worked-examples_mnist_pipeline_6_0.png
[5]:
parent_list = my_files.list_parents()

my_files.show(parent_list[0], image_width_resolution=800)
100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 103.47it/s]
[5]:
[<Figure size 1000x1000 with 1 Axes>]
../_images/Worked-examples_mnist_pipeline_7_2.png

Annotate

https://mapreader.readthedocs.io/en/latest/User-guide/Annotate.html

[6]:
from mapreader.annotate.utils import prepare_annotation, save_annotation

Set up inputs

[7]:
userID = "kasra"
annotation_tasks_file = "./annotation_tasks_mnist.yaml"
task = "mnist"
annotation_set = "task_mnist"

annotate = prepare_annotation(
    userID,
    task,
    tree_level="parent",
    annotation_tasks_file=annotation_tasks_file,
    annotation_set=annotation_set,
    sortby="mean",
    min_alpha_channel=0.01,
    xoffset=50,
    yoffset=50,
    context_image=True,
    list_shortcuts=["1", "3"],
)
100%|███████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2907.77it/s]
Number of already annotated images: 164
Number of images to be annotated (total): 36
[WARNING] could not find mean_pixel_R in columns.
Number of images to annotate (current batch): 36

Annotate images and save annotations

[8]:
annotate
[9]:
save_annotation(
    annotate,
    userID,
    task,
    annotation_tasks_file=annotation_tasks_file,
    annotation_set=annotation_set,
)
[INFO] Save 0 new annotations to ./annotations_mnist/mnist_#kasra#.csv
[INFO] 0 labels were not already stored
[INFO] Total number of saved annotations: 164
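Since the annotations are stored as a plain CSV, you can also inspect them directly with pandas (a quick optional check; only the "label" column referred to later is relied on here):

import pandas as pd

# read the annotations saved above and count how many of each label were stored
saved_annotations = pd.read_csv("./annotations_mnist/mnist_#kasra#.csv")
print(saved_annotations["label"].value_counts())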

Classify - Train

https://mapreader.readthedocs.io/en/latest/User-guide/Classify.html

Read annotations

[10]:
from mapreader import AnnotationsLoader
[11]:
annotated_images = AnnotationsLoader()

annotated_images.load("./annotations_mnist/mnist_#kasra#.csv")
[INFO] Reading "./annotations_mnist/mnist_#kasra#.csv"
[INFO] Number of annotations:   164

[INFO] Number of instances of each label (from column "label"):
        - 2:      87
        - 1:      77

[12]:
# show sample images for one label (label_to_show)
annotated_images.show_sample(label_to_show="1", num_samples=6)
../_images/Worked-examples_mnist_pipeline_19_0.png
[13]:
# show an image based on its ID
annotated_images.show_patch(patch_id="20989.png")
../_images/Worked-examples_mnist_pipeline_20_0.png

Prepare datasets and dataloaders

[14]:
annotated_images.create_datasets(frac_train=0.7, frac_val=0.2, frac_test=0.1)
[INFO] Number of annotations in each set:
        - Train:        114
        - Validate:     33
        - Test:         17
[15]:
dataloaders = annotated_images.create_dataloaders(batch_size=8, sampler="default")
[INFO] Using default sampler.
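Before training, you can optionally pull a single batch to check that the dataloaders behave as expected. This sketch assumes create_dataloaders returns a dictionary of PyTorch DataLoaders keyed by the set names used elsewhere in this example ("train", "val", "test"):

# grab one batch from the training dataloader as a quick sanity check
train_batch = next(iter(dataloaders["train"]))
print(type(train_batch), len(train_batch))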

Set up ClassifierContainer

Load a (pretrained) PyTorch model and combine with dataloaders

[16]:
from mapreader import ClassifierContainer
[38]:
my_classifier = ClassifierContainer(
    model="resnet18",
    labels_map={0: "3", 1: "1"},
    dataloaders=dataloaders,
)
[INFO] Device is set to cpu
[INFO] Loaded "train" with 114 items.
[INFO] Loaded "val" with 33 items.
[INFO] Loaded "test" with 17 items.
[INFO] Loaded "all_mnist" with 200 items.
[INFO] Initializing model.
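The labels_map argument ties the integer label indices used internally to the human-readable labels: index 0 corresponds to the digit "3" and index 1 to the digit "1". This is the same mapping behind the pred and predicted_label columns produced in the inference section below; a tiny self-contained illustration (just restating the mapping passed in above):

labels_map = {0: "3", 1: "1"}
# a predicted index of 0 therefore means digit "3", and 1 means digit "1"
print(labels_map[0], labels_map[1])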
[39]:
my_classifier.model_summary()
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [8, 2]                    --
├─Conv2d: 1-1                            [8, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [8, 64, 112, 112]         128
├─ReLU: 1-3                              [8, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [8, 64, 56, 56]           --
├─Sequential: 1-5                        [8, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [8, 64, 56, 56]           --
│    │    └─Conv2d: 3-1                  [8, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-2             [8, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [8, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [8, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-5             [8, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [8, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [8, 64, 56, 56]           --
│    │    └─Conv2d: 3-7                  [8, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-8             [8, 64, 56, 56]           128
│    │    └─ReLU: 3-9                    [8, 64, 56, 56]           --
│    │    └─Conv2d: 3-10                 [8, 64, 56, 56]           36,864
│    │    └─BatchNorm2d: 3-11            [8, 64, 56, 56]           128
│    │    └─ReLU: 3-12                   [8, 64, 56, 56]           --
├─Sequential: 1-6                        [8, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [8, 128, 28, 28]          --
│    │    └─Conv2d: 3-13                 [8, 128, 28, 28]          73,728
│    │    └─BatchNorm2d: 3-14            [8, 128, 28, 28]          256
│    │    └─ReLU: 3-15                   [8, 128, 28, 28]          --
│    │    └─Conv2d: 3-16                 [8, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-17            [8, 128, 28, 28]          256
│    │    └─Sequential: 3-18             [8, 128, 28, 28]          8,448
│    │    └─ReLU: 3-19                   [8, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [8, 128, 28, 28]          --
│    │    └─Conv2d: 3-20                 [8, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-21            [8, 128, 28, 28]          256
│    │    └─ReLU: 3-22                   [8, 128, 28, 28]          --
│    │    └─Conv2d: 3-23                 [8, 128, 28, 28]          147,456
│    │    └─BatchNorm2d: 3-24            [8, 128, 28, 28]          256
│    │    └─ReLU: 3-25                   [8, 128, 28, 28]          --
├─Sequential: 1-7                        [8, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [8, 256, 14, 14]          --
│    │    └─Conv2d: 3-26                 [8, 256, 14, 14]          294,912
│    │    └─BatchNorm2d: 3-27            [8, 256, 14, 14]          512
│    │    └─ReLU: 3-28                   [8, 256, 14, 14]          --
│    │    └─Conv2d: 3-29                 [8, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-30            [8, 256, 14, 14]          512
│    │    └─Sequential: 3-31             [8, 256, 14, 14]          33,280
│    │    └─ReLU: 3-32                   [8, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [8, 256, 14, 14]          --
│    │    └─Conv2d: 3-33                 [8, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-34            [8, 256, 14, 14]          512
│    │    └─ReLU: 3-35                   [8, 256, 14, 14]          --
│    │    └─Conv2d: 3-36                 [8, 256, 14, 14]          589,824
│    │    └─BatchNorm2d: 3-37            [8, 256, 14, 14]          512
│    │    └─ReLU: 3-38                   [8, 256, 14, 14]          --
├─Sequential: 1-8                        [8, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [8, 512, 7, 7]            --
│    │    └─Conv2d: 3-39                 [8, 512, 7, 7]            1,179,648
│    │    └─BatchNorm2d: 3-40            [8, 512, 7, 7]            1,024
│    │    └─ReLU: 3-41                   [8, 512, 7, 7]            --
│    │    └─Conv2d: 3-42                 [8, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-43            [8, 512, 7, 7]            1,024
│    │    └─Sequential: 3-44             [8, 512, 7, 7]            132,096
│    │    └─ReLU: 3-45                   [8, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [8, 512, 7, 7]            --
│    │    └─Conv2d: 3-46                 [8, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-47            [8, 512, 7, 7]            1,024
│    │    └─ReLU: 3-48                   [8, 512, 7, 7]            --
│    │    └─Conv2d: 3-49                 [8, 512, 7, 7]            2,359,296
│    │    └─BatchNorm2d: 3-50            [8, 512, 7, 7]            1,024
│    │    └─ReLU: 3-51                   [8, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [8, 512, 1, 1]            --
├─Linear: 1-10                           [8, 2]                    1,026
==========================================================================================
Total params: 11,177,538
Trainable params: 11,177,538
Non-trainable params: 0
Total mult-adds (G): 14.51
==========================================================================================
Input size (MB): 4.82
Forward/backward pass size (MB): 317.92
Params size (MB): 44.71
Estimated Total Size (MB): 367.44
==========================================================================================
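The final Linear layer is the only part of the network that depends on the number of labels: it maps the 512-dimensional pooled features to the two classes ("3" and "1"), which accounts for its 1,026 parameters:

# 512 weights per output class plus one bias per class
print(512 * 2 + 2)  # 1026, matching the Linear: 1-10 row in the summary above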

Set up criterion, optimizer and scheduler

[40]:
my_classifier.add_criterion("cross-entropy")
[INFO] Using "CrossEntropyLoss()" as criterion.
[41]:
params_to_optimize = my_classifier.generate_layerwise_lrs(
    min_lr=1e-4, max_lr=1e-3, spacing="geomspace"
)
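The "geomspace" option spaces the layerwise learning rates geometrically between min_lr and max_lr, so different groups of layers are fine-tuned at different rates (typically smaller for layers closer to the input). A rough illustration of the spacing itself, not MapReader's internal code (the number of parameter groups here is made up):

import numpy as np

# five hypothetical parameter groups spaced geometrically between 1e-4 and 1e-3
print(np.geomspace(1e-4, 1e-3, num=5))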
[42]:
my_classifier.initialize_optimizer(params2optimize=params_to_optimize)
[43]:
my_classifier.initialize_scheduler()

Train/fine-tune a model

[44]:
my_classifier.train(
    num_epochs=5,
    save_model_dir="./models_mnist",
    tensorboard_path="tboard_mnist",
    tmp_file_save_freq=2,
    remove_after_load=False,
    print_info_batch_freq=5,
)
[INFO] Each training step will pass: ['train', 'val'].
2023-06-05 10:32:24 599-JY5FK6 [INFO] train    -- 1/5 --      8/114 (  7.0% ) -- Loss: 0.596
2023-06-05 10:32:26 599-JY5FK6 [INFO] train    -- 1/5 --     48/114 ( 42.1% ) -- Loss: 0.014
2023-06-05 10:32:27 599-JY5FK6 [INFO] train    -- 1/5 --     88/114 ( 77.2% ) -- Loss: 0.001
2023-06-05 10:32:28 599-JY5FK6 [INFO] train    -- 1/5 -- Loss: 0.342; F_macro: 92.97; R_macro: 92.97
2023-06-05 10:32:28 599-JY5FK6 [INFO] val      -- 1/5 --       8/33 ( 24.2% ) -- Loss: 0.802
2023-06-05 10:32:28 599-JY5FK6 [INFO] val      -- 1/5 -- Loss: 0.556; F_macro: 90.60; R_macro: 90.00

2023-06-05 10:32:29 599-JY5FK6 [INFO] train    -- 2/5 --      8/114 (  7.0% ) -- Loss: 0.618
2023-06-05 10:32:30 599-JY5FK6 [INFO] train    -- 2/5 --     48/114 ( 42.1% ) -- Loss: 0.002
2023-06-05 10:32:31 599-JY5FK6 [INFO] train    -- 2/5 --     88/114 ( 77.2% ) -- Loss: 0.059
2023-06-05 10:32:32 599-JY5FK6 [INFO] train    -- 2/5 -- Loss: 0.092; F_macro: 97.37; R_macro: 97.38
2023-06-05 10:32:32 599-JY5FK6 [INFO] val      -- 2/5 --       8/33 ( 24.2% ) -- Loss: 0.000
2023-06-05 10:32:32 599-JY5FK6 [INFO] val      -- 2/5 -- Loss: 0.000; F_macro: 100.00; R_macro: 100.00

[INFO] Checkpoint file saved to "./tmp_checkpoints/tmp_6200949312_checkpoint.pkl".
2023-06-05 10:32:33 599-JY5FK6 [INFO] train    -- 3/5 --      8/114 (  7.0% ) -- Loss: 0.012
2023-06-05 10:32:34 599-JY5FK6 [INFO] train    -- 3/5 --     48/114 ( 42.1% ) -- Loss: 0.005
2023-06-05 10:32:35 599-JY5FK6 [INFO] train    -- 3/5 --     88/114 ( 77.2% ) -- Loss: 0.028
2023-06-05 10:32:36 599-JY5FK6 [INFO] train    -- 3/5 -- Loss: 0.066; F_macro: 98.25; R_macro: 98.28
2023-06-05 10:32:36 599-JY5FK6 [INFO] val      -- 3/5 --       8/33 ( 24.2% ) -- Loss: 0.000
2023-06-05 10:32:37 599-JY5FK6 [INFO] val      -- 3/5 -- Loss: 0.000; F_macro: 100.00; R_macro: 100.00

2023-06-05 10:32:37 599-JY5FK6 [INFO] train    -- 4/5 --      8/114 (  7.0% ) -- Loss: 0.004
2023-06-05 10:32:38 599-JY5FK6 [INFO] train    -- 4/5 --     48/114 ( 42.1% ) -- Loss: 0.001
2023-06-05 10:32:40 599-JY5FK6 [INFO] train    -- 4/5 --     88/114 ( 77.2% ) -- Loss: 0.033
2023-06-05 10:32:41 599-JY5FK6 [INFO] train    -- 4/5 -- Loss: 0.051; F_macro: 99.10; R_macro: 99.24
2023-06-05 10:32:41 599-JY5FK6 [INFO] val      -- 4/5 --       8/33 ( 24.2% ) -- Loss: 0.065
2023-06-05 10:32:41 599-JY5FK6 [INFO] val      -- 4/5 -- Loss: 0.121; F_macro: 96.92; R_macro: 96.67

[INFO] Checkpoint file saved to "./tmp_checkpoints/tmp_6200949312_checkpoint.pkl".
2023-06-05 10:32:41 599-JY5FK6 [INFO] train    -- 5/5 --      8/114 (  7.0% ) -- Loss: 0.000
2023-06-05 10:32:43 599-JY5FK6 [INFO] train    -- 5/5 --     48/114 ( 42.1% ) -- Loss: 0.080
2023-06-05 10:32:44 599-JY5FK6 [INFO] train    -- 5/5 --     88/114 ( 77.2% ) -- Loss: 0.002
2023-06-05 10:32:45 599-JY5FK6 [INFO] train    -- 5/5 -- Loss: 0.065; F_macro: 98.22; R_macro: 98.04
2023-06-05 10:32:45 599-JY5FK6 [INFO] val      -- 5/5 --       8/33 ( 24.2% ) -- Loss: 0.002
2023-06-05 10:32:45 599-JY5FK6 [INFO] val      -- 5/5 -- Loss: 0.002; F_macro: 100.00; R_macro: 100.00

[INFO] Total time: 0m 21s
[INFO] Model at epoch 3 has least valid loss (0.0002) so will be saved.
[INFO] Path: /Users/rwood/LwM/MapReader/worked_examples/non-geospatial/classification_mnist/models_mnist/checkpoint_3.pkl
[45]:
list(my_classifier.metrics.keys())
[45]:
['epoch_loss_train',
 'epoch_prec_micro_train',
 'epoch_recall_micro_train',
 'epoch_fscore_micro_train',
 'epoch_supp_micro_train',
 'epoch_rocauc_micro_train',
 'epoch_prec_macro_train',
 'epoch_recall_macro_train',
 'epoch_fscore_macro_train',
 'epoch_supp_macro_train',
 'epoch_rocauc_macro_train',
 'epoch_prec_weighted_train',
 'epoch_recall_weighted_train',
 'epoch_fscore_weighted_train',
 'epoch_supp_weighted_train',
 'epoch_rocauc_weighted_train',
 'epoch_prec_0_train',
 'epoch_recall_0_train',
 'epoch_fscore_0_train',
 'epoch_supp_0_train',
 'epoch_prec_1_train',
 'epoch_recall_1_train',
 'epoch_fscore_1_train',
 'epoch_supp_1_train',
 'epoch_loss_val',
 'epoch_prec_micro_val',
 'epoch_recall_micro_val',
 'epoch_fscore_micro_val',
 'epoch_supp_micro_val',
 'epoch_rocauc_micro_val',
 'epoch_prec_macro_val',
 'epoch_recall_macro_val',
 'epoch_fscore_macro_val',
 'epoch_supp_macro_val',
 'epoch_rocauc_macro_val',
 'epoch_prec_weighted_val',
 'epoch_recall_weighted_val',
 'epoch_fscore_weighted_val',
 'epoch_supp_weighted_val',
 'epoch_rocauc_weighted_val',
 'epoch_prec_0_val',
 'epoch_recall_0_val',
 'epoch_fscore_0_val',
 'epoch_supp_0_val',
 'epoch_prec_1_val',
 'epoch_recall_1_val',
 'epoch_fscore_1_val',
 'epoch_supp_1_val']
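Each key should hold the per-epoch values recorded during training (this is what plot_metric below plots against epoch). You can also inspect them directly, for example to confirm which epoch had the lowest validation loss, i.e. the checkpoint saved above:

# epoch_loss_val holds one value per epoch; epochs are 1-indexed in the training logs
val_losses = list(my_classifier.metrics["epoch_loss_val"])
best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i]) + 1
print(best_epoch, min(val_losses))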
[46]:
my_classifier.plot_metric(
    y_axis=["epoch_loss_train", "epoch_loss_val"],
    y_label="Loss",
    legends=["Train", "Valid"],
    colors=["k", "tab:red"],
)
../_images/Worked-examples_mnist_pipeline_37_0.png
[47]:
my_classifier.plot_metric(
    y_axis=["epoch_rocauc_macro_train", "epoch_rocauc_macro_val"],
    y_label="ROC AUC",
    legends=["Train", "Valid"],
    colors=["k", "tab:red"],
)
../_images/Worked-examples_mnist_pipeline_38_0.png
[48]:
my_classifier.plot_metric(
    y_axis=[
        "epoch_fscore_macro_train",
        "epoch_fscore_macro_val",
        "epoch_fscore_0_val",
        "epoch_fscore_1_val",
    ],
    y_label="F-score",
    legends=[
        "Train",
        "Valid",
        "Valid (label: 0)",
        "Valid (label: 1)",
    ],
    colors=["k", "tab:red", "tab:red", "tab:red"],
    styles=["-", "-", "--", ":"],
    markers=["o", "o", "", ""],
    plt_yrange=[0, 100],
)
../_images/Worked-examples_mnist_pipeline_39_0.png
[49]:
my_classifier.plot_metric(
    y_axis=[
        "epoch_recall_macro_train",
        "epoch_recall_macro_val",
        "epoch_recall_0_val",
        "epoch_recall_1_val",
    ],
    y_label="Recall",
    legends=[
        "Train",
        "Valid",
        "Valid (label: 0)",
        "Valid (label: 1)",
    ],
    colors=["k", "tab:red", "tab:red", "tab:red"],
    styles=["-", "-", "--", ":"],
    markers=["o", "o", "", ""],
    plt_yrange=[0, 100],
)
../_images/Worked-examples_mnist_pipeline_40_0.png

Classify - Infer

https://mapreader.readthedocs.io/en/latest/User-guide/Classify.html

Create dataset with all MNIST data and add to ClassifierContainer

[50]:
# create parent- and patch-level dataframes from the MapImages object
parent_df, patch_df = my_files.convert_images()
parent_df.head()
[50]:
          parent_id                                         image_path     shape
image_id
49081.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)
51816.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)
39566.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)
24251.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)
20989.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)
[51]:
from mapreader import PatchDataset
[52]:
data = PatchDataset(parent_df, transform="val")
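Each row of the dataframe becomes one item in the dataset, so it should cover all 200 images; a quick check, assuming PatchDataset implements the standard PyTorch Dataset length behaviour:

print(len(data))  # expected: 200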
[53]:
my_classifier.load_dataset(data, set_name="all_mnist")

Infer labels

[54]:
my_classifier.inference("all_mnist")
[INFO] Each training step will pass: ['all_mnist'].
2023-06-05 10:32:46 599-JY5FK6 [INFO] all_mnist -- 6/6 --     16/200 (  8.0% ) -- 
2023-06-05 10:32:49 599-JY5FK6 [INFO] all_mnist -- 6/6 --     96/200 ( 48.0% ) -- 
2023-06-05 10:32:52 599-JY5FK6 [INFO] all_mnist -- 6/6 --    176/200 ( 88.0% ) -- 
[INFO] Total time: 0m 7s
[55]:
my_classifier.show_inference_sample_results(
    label="3", set_name="all_mnist", min_conf=99
)
../_images/Worked-examples_mnist_pipeline_49_0.png

Add predictions to dataframe

[56]:
predictions_df = data.patch_df
[57]:
import numpy as np
import pandas as pd

# add the predicted (human-readable) label, the predicted label index, and the
# model's confidence in that prediction (maximum value of pred_conf across the two classes)
predictions_df["predicted_label"] = my_classifier.pred_label
predictions_df["pred"] = my_classifier.pred_label_indices
predictions_df["conf"] = np.array(my_classifier.pred_conf).max(axis=1)

predictions_df.head()
[57]:
          parent_id                                         image_path     shape predicted_label  pred      conf
image_id
49081.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)               3     0  0.999915
51816.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)               3     0  0.999805
39566.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)               3     0  0.999222
24251.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)               3     0  0.999954
20989.png      None  /Users/rwood/LwM/MapReader/worked_examples/non...  (28, 28)               3     0  0.999804
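Before writing the predictions out, it can be useful to summarise how many images were assigned each label and how confident the model was on average, using the columns just created:

# distribution of predicted labels and mean confidence per predicted label
print(predictions_df["predicted_label"].value_counts())
print(predictions_df.groupby("predicted_label")["conf"].mean())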
[58]:
predictions_df.to_csv("./predictions_df.csv", sep=",", index_label="image_id")