Classification of MNIST dataset
This example shows how to use MapReader on non-patchified datasets, i.e. on whole images that are not sliced into patches.
Load
https://mapreader.readthedocs.io/en/latest/User-guide/Load.html
Load images
[1]:
from mapreader import loader
path2images = "./small_mnist/*.png"
my_files = loader(path2images)
100%|███████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2711.21it/s]
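The `path2images` string is a glob-style pattern. The sketch below shows which files such a pattern selects. It assumes that `loader` expands the pattern the way Python's standard `glob` module does, which is not confirmed here. The snippet creates a throwaway folder and matches only the `.png` files in it. The file names come from the `print(my_files)` listing in this notebook, plus a hypothetical `notes.txt`.

```python
import glob
import os
import tempfile

def matching_images(pattern):
    """Return the sorted file names matched by a glob pattern such as "./small_mnist/*.png"."""
    return sorted(os.path.basename(p) for p in glob.glob(pattern))

# Build a temporary folder with empty placeholder files to illustrate the pattern
with tempfile.TemporaryDirectory() as folder:
    for name in ["49081.png", "51816.png", "39566.png", "notes.txt"]:
        open(os.path.join(folder, name), "w").close()
    matched = matching_images(os.path.join(folder, "*.png"))

print(matched)  # → ['39566.png', '49081.png', '51816.png'] (notes.txt is not matched)
```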
[2]:
# len() shows the total number of images currently loaded (here, 200 parent images and no patches)
print(f"Number of images: {len(my_files)}")
Number of images: 200
[3]:
print(my_files)
#images: 200
#parents: 200
49081.png
51816.png
39566.png
24251.png
20989.png
29013.png
10692.png
58832.png
30556.png
10686.png
9294.png
...
#patches: 0
[4]:
my_files.show_sample(num_samples=3, tree_level="parent")
[5]:
parent_list = my_files.list_parents()
my_files.show(parent_list[0], image_width_resolution=800)
100%|████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 103.47it/s]
[5]:
[<Figure size 1000x1000 with 1 Axes>]
Annotate
https://mapreader.readthedocs.io/en/latest/User-guide/Annotate.html
[6]:
from mapreader.annotate.utils import prepare_annotation, save_annotation
Set up inputs
[7]:
userID = "kasra"
annotation_tasks_file = "./annotation_tasks_mnist.yaml"
task = "mnist"
annotation_set = "task_mnist"
annotate = prepare_annotation(userID,
                              task,
                              tree_level="parent",
                              annotation_tasks_file=annotation_tasks_file,
                              annotation_set=annotation_set,
                              sortby="mean",
                              min_alpha_channel=0.01,
                              xoffset=50,
                              yoffset=50,
                              context_image=True,
                              list_shortcuts=["1", "3"])
100%|███████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 2907.77it/s]
Number of already annotated images: 164
Number of images to be annotated (total): 36
[WARNING] could not find mean_pixel_R in columns.
Number of images to annotate (current batch): 36
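The `sortby="mean"` option appears to order images by their mean pixel intensity. The warning above suggests it expects a precomputed `mean_pixel_R` column, which was not available here.

The sketch below illustrates that idea on tiny hypothetical grayscale images stored as nested lists. Both the sort direction (brightest first) and the use of a plain mean are assumptions, not MapReader's confirmed behaviour.

```python
def mean_pixel(image):
    """Mean intensity of a grayscale image given as a list of rows."""
    values = [v for row in image for v in row]
    return sum(values) / len(values)

# Hypothetical 2x2 images keyed by file name
images = {
    "a.png": [[0, 0], [0, 255]],
    "b.png": [[255, 255], [255, 255]],
    "c.png": [[0, 0], [0, 0]],
}

# Order image IDs by mean intensity, brightest first (direction is an assumption)
order = sorted(images, key=lambda name: mean_pixel(images[name]), reverse=True)
print(order)  # → ['b.png', 'a.png', 'c.png']
```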
Annotate images and save annotations
[8]:
annotate
[9]:
save_annotation(annotate,
                userID,
                task,
                annotation_tasks_file=annotation_tasks_file,
                annotation_set=annotation_set)
[INFO] Save 0 new annotations to ./annotations_mnist/mnist_#kasra#.csv
[INFO] 0 labels were not already stored
[INFO] Total number of saved annotations: 164
Classify - Train
https://mapreader.readthedocs.io/en/latest/User-guide/Classify.html
Read annotations
[10]:
from mapreader import AnnotationsLoader
[11]:
annotated_images = AnnotationsLoader()
annotated_images.load("./annotations_mnist/mnist_#kasra#.csv")
[INFO] Reading "./annotations_mnist/mnist_#kasra#.csv"
[INFO] Number of annotations: 164
[INFO] Number of instances of each label (from column "label"):
- 2: 87
- 1: 77
[12]:
# show sample images for one label (label_to_show)
annotated_images.show_sample(label_to_show="1", num_samples=6)
[13]:
# show an image based on its ID (here, the file name)
annotated_images.show_patch(patch_id="20989.png")
Prepare datasets and dataloaders
[14]:
annotated_images.create_datasets(frac_train=0.7,
                                 frac_val=0.2,
                                 frac_test=0.1)
[INFO] Number of annotations in each set:
- Train: 114
- Validate: 33
- Test: 17
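The reported set sizes (114/33/17 from 164 annotations) are consistent with a two-stage split:

- first, train vs. the rest;
- then, the rest into validation vs. test;
- each held-out portion is rounded up, as scikit-learn's `train_test_split` does.

The sketch below reproduces those counts in plain Python under that assumption. It is not MapReader's actual code, and unlike a stratified split it shuffles without regard to labels.

```python
import math
import random

def split_sizes(n, frac_train, frac_val, frac_test):
    """Two-stage split sizes: train vs. rest, then rest into val vs. test (held-out parts rounded up)."""
    n_rest = math.ceil((1 - frac_train) * n)
    n_train = n - n_rest
    n_test = math.ceil(frac_test / (frac_val + frac_test) * n_rest)
    n_val = n_rest - n_test
    return n_train, n_val, n_test

def random_split(items, frac_train, frac_val, frac_test, seed=0):
    """Shuffle a copy of items and cut it into train/val/test lists of the sizes above."""
    n_train, n_val, _ = split_sizes(len(items), frac_train, frac_val, frac_test)
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:n_train + n_val], shuffled[n_train + n_val:]

print(split_sizes(164, 0.7, 0.2, 0.1))  # → (114, 33, 17)
```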
[15]:
dataloaders = annotated_images.create_dataloaders(batch_size=8, sampler="default")
[INFO] Using default sampler.
Set up ClassifierContainer
Load a (pretrained) PyTorch model and combine with dataloaders
[16]:
from mapreader import ClassifierContainer
[38]:
my_classifier = ClassifierContainer(model="resnet18",
                                    dataloaders=dataloaders,
                                    labels_map={0: "3", 1: "1"})
[INFO] Device is set to cpu
[INFO] Loaded "train" with 114 items.
[INFO] Loaded "val" with 33 items.
[INFO] Loaded "test" with 17 items.
[INFO] Loaded "all_mnist" with 200 items.
[INFO] Initializing model.
[39]:
my_classifier.model_summary()
===================================================================================================================
Layer (type:depth-idx) Output Shape Output Shape Param #
===================================================================================================================
ResNet [8, 2] [8, 2] --
├─Conv2d: 1-1 [8, 64, 112, 112] [8, 64, 112, 112] 9,408
├─BatchNorm2d: 1-2 [8, 64, 112, 112] [8, 64, 112, 112] 128
├─ReLU: 1-3 [8, 64, 112, 112] [8, 64, 112, 112] --
├─MaxPool2d: 1-4 [8, 64, 56, 56] [8, 64, 56, 56] --
├─Sequential: 1-5 [8, 64, 56, 56] [8, 64, 56, 56] --
│ └─BasicBlock: 2-1 [8, 64, 56, 56] [8, 64, 56, 56] --
│ │ └─Conv2d: 3-1 [8, 64, 56, 56] [8, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-2 [8, 64, 56, 56] [8, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [8, 64, 56, 56] [8, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [8, 64, 56, 56] [8, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-5 [8, 64, 56, 56] [8, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [8, 64, 56, 56] [8, 64, 56, 56] --
│ └─BasicBlock: 2-2 [8, 64, 56, 56] [8, 64, 56, 56] --
│ │ └─Conv2d: 3-7 [8, 64, 56, 56] [8, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-8 [8, 64, 56, 56] [8, 64, 56, 56] 128
│ │ └─ReLU: 3-9 [8, 64, 56, 56] [8, 64, 56, 56] --
│ │ └─Conv2d: 3-10 [8, 64, 56, 56] [8, 64, 56, 56] 36,864
│ │ └─BatchNorm2d: 3-11 [8, 64, 56, 56] [8, 64, 56, 56] 128
│ │ └─ReLU: 3-12 [8, 64, 56, 56] [8, 64, 56, 56] --
├─Sequential: 1-6 [8, 128, 28, 28] [8, 128, 28, 28] --
│ └─BasicBlock: 2-3 [8, 128, 28, 28] [8, 128, 28, 28] --
│ │ └─Conv2d: 3-13 [8, 128, 28, 28] [8, 128, 28, 28] 73,728
│ │ └─BatchNorm2d: 3-14 [8, 128, 28, 28] [8, 128, 28, 28] 256
│ │ └─ReLU: 3-15 [8, 128, 28, 28] [8, 128, 28, 28] --
│ │ └─Conv2d: 3-16 [8, 128, 28, 28] [8, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-17 [8, 128, 28, 28] [8, 128, 28, 28] 256
│ │ └─Sequential: 3-18 [8, 128, 28, 28] [8, 128, 28, 28] 8,448
│ │ └─ReLU: 3-19 [8, 128, 28, 28] [8, 128, 28, 28] --
│ └─BasicBlock: 2-4 [8, 128, 28, 28] [8, 128, 28, 28] --
│ │ └─Conv2d: 3-20 [8, 128, 28, 28] [8, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-21 [8, 128, 28, 28] [8, 128, 28, 28] 256
│ │ └─ReLU: 3-22 [8, 128, 28, 28] [8, 128, 28, 28] --
│ │ └─Conv2d: 3-23 [8, 128, 28, 28] [8, 128, 28, 28] 147,456
│ │ └─BatchNorm2d: 3-24 [8, 128, 28, 28] [8, 128, 28, 28] 256
│ │ └─ReLU: 3-25 [8, 128, 28, 28] [8, 128, 28, 28] --
├─Sequential: 1-7 [8, 256, 14, 14] [8, 256, 14, 14] --
│ └─BasicBlock: 2-5 [8, 256, 14, 14] [8, 256, 14, 14] --
│ │ └─Conv2d: 3-26 [8, 256, 14, 14] [8, 256, 14, 14] 294,912
│ │ └─BatchNorm2d: 3-27 [8, 256, 14, 14] [8, 256, 14, 14] 512
│ │ └─ReLU: 3-28 [8, 256, 14, 14] [8, 256, 14, 14] --
│ │ └─Conv2d: 3-29 [8, 256, 14, 14] [8, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-30 [8, 256, 14, 14] [8, 256, 14, 14] 512
│ │ └─Sequential: 3-31 [8, 256, 14, 14] [8, 256, 14, 14] 33,280
│ │ └─ReLU: 3-32 [8, 256, 14, 14] [8, 256, 14, 14] --
│ └─BasicBlock: 2-6 [8, 256, 14, 14] [8, 256, 14, 14] --
│ │ └─Conv2d: 3-33 [8, 256, 14, 14] [8, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-34 [8, 256, 14, 14] [8, 256, 14, 14] 512
│ │ └─ReLU: 3-35 [8, 256, 14, 14] [8, 256, 14, 14] --
│ │ └─Conv2d: 3-36 [8, 256, 14, 14] [8, 256, 14, 14] 589,824
│ │ └─BatchNorm2d: 3-37 [8, 256, 14, 14] [8, 256, 14, 14] 512
│ │ └─ReLU: 3-38 [8, 256, 14, 14] [8, 256, 14, 14] --
├─Sequential: 1-8 [8, 512, 7, 7] [8, 512, 7, 7] --
│ └─BasicBlock: 2-7 [8, 512, 7, 7] [8, 512, 7, 7] --
│ │ └─Conv2d: 3-39 [8, 512, 7, 7] [8, 512, 7, 7] 1,179,648
│ │ └─BatchNorm2d: 3-40 [8, 512, 7, 7] [8, 512, 7, 7] 1,024
│ │ └─ReLU: 3-41 [8, 512, 7, 7] [8, 512, 7, 7] --
│ │ └─Conv2d: 3-42 [8, 512, 7, 7] [8, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-43 [8, 512, 7, 7] [8, 512, 7, 7] 1,024
│ │ └─Sequential: 3-44 [8, 512, 7, 7] [8, 512, 7, 7] 132,096
│ │ └─ReLU: 3-45 [8, 512, 7, 7] [8, 512, 7, 7] --
│ └─BasicBlock: 2-8 [8, 512, 7, 7] [8, 512, 7, 7] --
│ │ └─Conv2d: 3-46 [8, 512, 7, 7] [8, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-47 [8, 512, 7, 7] [8, 512, 7, 7] 1,024
│ │ └─ReLU: 3-48 [8, 512, 7, 7] [8, 512, 7, 7] --
│ │ └─Conv2d: 3-49 [8, 512, 7, 7] [8, 512, 7, 7] 2,359,296
│ │ └─BatchNorm2d: 3-50 [8, 512, 7, 7] [8, 512, 7, 7] 1,024
│ │ └─ReLU: 3-51 [8, 512, 7, 7] [8, 512, 7, 7] --
├─AdaptiveAvgPool2d: 1-9 [8, 512, 1, 1] [8, 512, 1, 1] --
├─Linear: 1-10 [8, 2] [8, 2] 1,026
===================================================================================================================
Total params: 11,177,538
Trainable params: 11,177,538
Non-trainable params: 0
Total mult-adds (G): 14.51
===================================================================================================================
Input size (MB): 4.82
Forward/backward pass size (MB): 317.92
Params size (MB): 44.71
Estimated Total Size (MB): 367.44
===================================================================================================================
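Two numbers in this summary can be checked by hand.

- **Input size.** 4.82 MB corresponds to a batch of 8 three-channel 224x224 float32 tensors. This suggests the 28x28 MNIST digits are resized and expanded to three channels before reaching the model.
- **Parameter count.** 11,177,538 equals torchvision's standard ResNet-18 (11,689,512 parameters with a 1000-class head) with its final `Linear` layer swapped for a 2-class one.

The arithmetic, as a quick sanity check:

```python
# Input tensor: batch of 8 RGB images at 224x224, stored as 4-byte float32 values
batch, channels, height, width, bytes_per_value = 8, 3, 224, 224, 4
input_mb = batch * channels * height * width * bytes_per_value / 1e6
print(round(input_mb, 2))  # → 4.82

def linear_params(in_features, out_features):
    """Weights plus one bias per output unit for a fully connected layer."""
    return in_features * out_features + out_features

resnet18_imagenet_params = 11_689_512  # torchvision ResNet-18 with a 1000-class head
two_class_params = (resnet18_imagenet_params
                    - linear_params(512, 1000)
                    + linear_params(512, 2))
print(linear_params(512, 2), two_class_params)  # → 1026 11177538
```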
Set up criterion, optimizer and scheduler
[40]:
my_classifier.add_criterion("cross-entropy")
[INFO] Using "CrossEntropyLoss()" as criterion.
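PyTorch's `CrossEntropyLoss` applies a softmax to the model's raw scores (logits). It then takes the negative log of the probability assigned to the true class and, by default, averages over the batch.

A plain-Python sketch of that computation for the two-class case follows. The logits are made-up values.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    # Subtract the largest logit before exponentiating for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """Average loss over a batch, matching the default 'mean' reduction."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

# Hypothetical logits: the first image is confidently class 0, the second is nearly a tie
print(mean_cross_entropy([[4.0, -1.0], [0.2, 0.1]], [0, 1]))
```

A confident correct prediction contributes a loss near 0, while an uncertain one contributes roughly log(2), about 0.69. This is why the per-batch losses in the training log can drop to 0.000.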
[41]:
params_to_optimize = my_classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3, spacing='geomspace')
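With `spacing='geomspace'`, the learning rates between `min_lr` and `max_lr` are spaced evenly on a log scale, with a constant ratio between neighbouring layers, rather than linearly.

The sketch below computes such a schedule for a hypothetical number of parameter groups. Which end of the network receives `min_lr` is an assumption here; the common fine-tuning choice gives the smallest rate to the earliest layers.

```python
def geomspace(start, stop, num):
    """Return num values from start to stop with a constant ratio between neighbours."""
    if num == 1:
        return [start]
    ratio = (stop / start) ** (1 / (num - 1))
    return [start * ratio ** i for i in range(num)]

def linspace(start, stop, num):
    """Return num evenly spaced values from start to stop (for comparison)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + step * i for i in range(num)]

# Hypothetical: 4 parameter groups, ordered from the input end to the output end
for lr_geo, lr_lin in zip(geomspace(1e-4, 1e-3, 4), linspace(1e-4, 1e-3, 4)):
    print(f"{lr_geo:.2e}  {lr_lin:.2e}")
```

Geometric spacing keeps most groups close to `min_lr` and raises the rate sharply only near `max_lr`. Linear spacing moves the middle groups much closer to `max_lr`.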
[42]:
my_classifier.initialize_optimizer(params2optimize=params_to_optimize)
[43]:
my_classifier.initialize_scheduler()
Train/fine-tune a model
[44]:
my_classifier.train(num_epochs=5,
                    save_model_dir="./models_mnist",
                    tensorboard_path="tboard_mnist",
                    tmp_file_save_freq=2,
                    remove_after_load=False,
                    print_info_batch_freq=5)
[INFO] Each training step will pass: ['train', 'val'].
2023-06-05 10:32:24 599-JY5FK6 [INFO] train -- 1/5 -- 8/114 ( 7.0% ) -- Loss: 0.596
2023-06-05 10:32:26 599-JY5FK6 [INFO] train -- 1/5 -- 48/114 ( 42.1% ) -- Loss: 0.014
2023-06-05 10:32:27 599-JY5FK6 [INFO] train -- 1/5 -- 88/114 ( 77.2% ) -- Loss: 0.001
2023-06-05 10:32:28 599-JY5FK6 [INFO] train -- 1/5 -- Loss: 0.342; F_macro: 92.97; R_macro: 92.97
2023-06-05 10:32:28 599-JY5FK6 [INFO] val -- 1/5 -- 8/33 ( 24.2% ) -- Loss: 0.802
2023-06-05 10:32:28 599-JY5FK6 [INFO] val -- 1/5 -- Loss: 0.556; F_macro: 90.60; R_macro: 90.00
2023-06-05 10:32:29 599-JY5FK6 [INFO] train -- 2/5 -- 8/114 ( 7.0% ) -- Loss: 0.618
2023-06-05 10:32:30 599-JY5FK6 [INFO] train -- 2/5 -- 48/114 ( 42.1% ) -- Loss: 0.002
2023-06-05 10:32:31 599-JY5FK6 [INFO] train -- 2/5 -- 88/114 ( 77.2% ) -- Loss: 0.059
2023-06-05 10:32:32 599-JY5FK6 [INFO] train -- 2/5 -- Loss: 0.092; F_macro: 97.37; R_macro: 97.38
2023-06-05 10:32:32 599-JY5FK6 [INFO] val -- 2/5 -- 8/33 ( 24.2% ) -- Loss: 0.000
2023-06-05 10:32:32 599-JY5FK6 [INFO] val -- 2/5 -- Loss: 0.000; F_macro: 100.00; R_macro: 100.00
[INFO] Checkpoint file saved to "./tmp_checkpoints/tmp_6200949312_checkpoint.pkl".
2023-06-05 10:32:33 599-JY5FK6 [INFO] train -- 3/5 -- 8/114 ( 7.0% ) -- Loss: 0.012
2023-06-05 10:32:34 599-JY5FK6 [INFO] train -- 3/5 -- 48/114 ( 42.1% ) -- Loss: 0.005
2023-06-05 10:32:35 599-JY5FK6 [INFO] train -- 3/5 -- 88/114 ( 77.2% ) -- Loss: 0.028
2023-06-05 10:32:36 599-JY5FK6 [INFO] train -- 3/5 -- Loss: 0.066; F_macro: 98.25; R_macro: 98.28
2023-06-05 10:32:36 599-JY5FK6 [INFO] val -- 3/5 -- 8/33 ( 24.2% ) -- Loss: 0.000
2023-06-05 10:32:37 599-JY5FK6 [INFO] val -- 3/5 -- Loss: 0.000; F_macro: 100.00; R_macro: 100.00
2023-06-05 10:32:37 599-JY5FK6 [INFO] train -- 4/5 -- 8/114 ( 7.0% ) -- Loss: 0.004
2023-06-05 10:32:38 599-JY5FK6 [INFO] train -- 4/5 -- 48/114 ( 42.1% ) -- Loss: 0.001
2023-06-05 10:32:40 599-JY5FK6 [INFO] train -- 4/5 -- 88/114 ( 77.2% ) -- Loss: 0.033
2023-06-05 10:32:41 599-JY5FK6 [INFO] train -- 4/5 -- Loss: 0.051; F_macro: 99.10; R_macro: 99.24
2023-06-05 10:32:41 599-JY5FK6 [INFO] val -- 4/5 -- 8/33 ( 24.2% ) -- Loss: 0.065
2023-06-05 10:32:41 599-JY5FK6 [INFO] val -- 4/5 -- Loss: 0.121; F_macro: 96.92; R_macro: 96.67
[INFO] Checkpoint file saved to "./tmp_checkpoints/tmp_6200949312_checkpoint.pkl".
2023-06-05 10:32:41 599-JY5FK6 [INFO] train -- 5/5 -- 8/114 ( 7.0% ) -- Loss: 0.000
2023-06-05 10:32:43 599-JY5FK6 [INFO] train -- 5/5 -- 48/114 ( 42.1% ) -- Loss: 0.080
2023-06-05 10:32:44 599-JY5FK6 [INFO] train -- 5/5 -- 88/114 ( 77.2% ) -- Loss: 0.002
2023-06-05 10:32:45 599-JY5FK6 [INFO] train -- 5/5 -- Loss: 0.065; F_macro: 98.22; R_macro: 98.04
2023-06-05 10:32:45 599-JY5FK6 [INFO] val -- 5/5 -- 8/33 ( 24.2% ) -- Loss: 0.002
2023-06-05 10:32:45 599-JY5FK6 [INFO] val -- 5/5 -- Loss: 0.002; F_macro: 100.00; R_macro: 100.00
[INFO] Total time: 0m 21s
[INFO] Model at epoch 3 has least valid loss (0.0002) so will be saved.
[INFO] Path: /Users/rwood/LwM/MapReader/worked_examples/non-geospatial/classification_mnist/models_mnist/checkpoint_3.pkl
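The progress lines follow from the set sizes and `batch_size=8`. With 114 training images an epoch has 15 batches. With `print_info_batch_freq=5`, a line appears at batches 1, 6 and 11, i.e. after 8, 48 and 88 images.

The sketch below reproduces those numbers. It assumes a line is logged whenever the zero-based batch index is a multiple of the frequency; that rule is inferred from the output above, not taken from MapReader's source.

```python
import math

def progress_lines(n_items, batch_size, freq):
    """Yield (items_seen, percent) for the batches that would be logged."""
    n_batches = math.ceil(n_items / batch_size)
    for batch_idx in range(n_batches):
        if batch_idx % freq == 0:
            seen = min((batch_idx + 1) * batch_size, n_items)
            yield seen, round(100 * seen / n_items, 1)

print(list(progress_lines(114, 8, 5)))  # train → [(8, 7.0), (48, 42.1), (88, 77.2)]
print(list(progress_lines(33, 8, 5)))   # val   → [(8, 24.2)]
```

The same rule with a batch size of 16 also matches the inference log for `all_mnist` (16, 96 and 176 of 200). That batch size is inferred, not set anywhere in this notebook.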
[45]:
list(my_classifier.metrics.keys())
[45]:
['epoch_loss_train',
'epoch_prec_micro_train',
'epoch_recall_micro_train',
'epoch_fscore_micro_train',
'epoch_supp_micro_train',
'epoch_rocauc_micro_train',
'epoch_prec_macro_train',
'epoch_recall_macro_train',
'epoch_fscore_macro_train',
'epoch_supp_macro_train',
'epoch_rocauc_macro_train',
'epoch_prec_weighted_train',
'epoch_recall_weighted_train',
'epoch_fscore_weighted_train',
'epoch_supp_weighted_train',
'epoch_rocauc_weighted_train',
'epoch_prec_0_train',
'epoch_recall_0_train',
'epoch_fscore_0_train',
'epoch_supp_0_train',
'epoch_prec_1_train',
'epoch_recall_1_train',
'epoch_fscore_1_train',
'epoch_supp_1_train',
'epoch_loss_val',
'epoch_prec_micro_val',
'epoch_recall_micro_val',
'epoch_fscore_micro_val',
'epoch_supp_micro_val',
'epoch_rocauc_micro_val',
'epoch_prec_macro_val',
'epoch_recall_macro_val',
'epoch_fscore_macro_val',
'epoch_supp_macro_val',
'epoch_rocauc_macro_val',
'epoch_prec_weighted_val',
'epoch_recall_weighted_val',
'epoch_fscore_weighted_val',
'epoch_supp_weighted_val',
'epoch_rocauc_weighted_val',
'epoch_prec_0_val',
'epoch_recall_0_val',
'epoch_fscore_0_val',
'epoch_supp_0_val',
'epoch_prec_1_val',
'epoch_recall_1_val',
'epoch_fscore_1_val',
'epoch_supp_1_val']
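The metric names above follow a visible pattern, `epoch_<metric>_<average or class index>_<split>`. The loss is the exception: it has no averaging part.

A small hypothetical helper that splits names into those parts follows. It can help pick the series to pass to `plot_metric`, and it relies only on the naming pattern in the list above.

```python
def parse_metric_key(key):
    """Split e.g. 'epoch_fscore_macro_val' into ('fscore', 'macro', 'val')."""
    parts = key.split("_")
    if parts[0] != "epoch" or len(parts) not in (3, 4):
        raise ValueError(f"unexpected metric key: {key!r}")
    if len(parts) == 3:  # e.g. 'epoch_loss_train' has no averaging part
        return parts[1], None, parts[2]
    return parts[1], parts[2], parts[3]

def select(keys, metric, split):
    """Return the keys for one metric and one split, e.g. all validation F-scores."""
    return [k for k in keys
            if parse_metric_key(k)[0] == metric and parse_metric_key(k)[2] == split]

# A few of the keys listed above
keys = ["epoch_loss_train", "epoch_fscore_macro_val", "epoch_recall_0_val", "epoch_supp_1_train"]
print([parse_metric_key(k) for k in keys])
# → [('loss', None, 'train'), ('fscore', 'macro', 'val'), ('recall', '0', 'val'), ('supp', '1', 'train')]
fscore_val = select(keys, "fscore", "val")
```

With the trained container, the same filter could be applied to `list(my_classifier.metrics.keys())`.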
[46]:
my_classifier.plot_metric(y_axis=["epoch_loss_train", "epoch_loss_val"],
                          y_label="Loss",
                          legends=["Train", "Valid"],
                          colors=["k", "tab:red"])
[47]:
my_classifier.plot_metric(y_axis=["epoch_rocauc_macro_train", "epoch_rocauc_macro_val"],
                          y_label="ROC AUC",
                          legends=["Train", "Valid"],
                          colors=["k", "tab:red"])
[48]:
my_classifier.plot_metric(y_axis=["epoch_fscore_macro_train",
                                  "epoch_fscore_macro_val",
                                  "epoch_fscore_0_val",
                                  "epoch_fscore_1_val"],
                          y_label="F-score",
                          legends=["Train",
                                   "Valid",
                                   "Valid (label: 0)",
                                   "Valid (label: 1)"],
                          colors=["k", "tab:red", "tab:red", "tab:red"],
                          styles=["-", "-", "--", ":"],
                          markers=["o", "o", "", ""],
                          plt_yrange=[0, 100])
[49]:
my_classifier.plot_metric(y_axis=["epoch_recall_macro_train",
                                  "epoch_recall_macro_val",
                                  "epoch_recall_0_val",
                                  "epoch_recall_1_val"],
                          y_label="Recall",
                          legends=["Train",
                                   "Valid",
                                   "Valid (label: 0)",
                                   "Valid (label: 1)"],
                          colors=["k", "tab:red", "tab:red", "tab:red"],
                          styles=["-", "-", "--", ":"],
                          markers=["o", "o", "", ""],
                          plt_yrange=[0, 100])
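The macro-averaged scores plotted here, and logged as `F_macro` and `R_macro` during training, are the unweighted mean of the per-class scores, scaled to percent.

A plain-Python sketch of per-class recall and F-score and their macro averages follows, using made-up labels for six hypothetical validation images.

```python
def per_class_scores(y_true, y_pred, labels):
    """Return {label: (recall, fscore)} computed from true and predicted labels."""
    scores = {}
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        recall = tp / (tp + fn) if tp + fn else 0.0
        precision = tp / (tp + fp) if tp + fp else 0.0
        f = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[c] = (recall, f)
    return scores

def macro(scores, index):
    """Unweighted mean over classes of one score (0 = recall, 1 = F-score), in percent."""
    return 100 * sum(s[index] for s in scores.values()) / len(scores)

# Hypothetical labels for 6 validation images (class indices 0 and 1)
y_true = [0, 0, 0, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 1]
scores = per_class_scores(y_true, y_pred, labels=[0, 1])
print(round(macro(scores, 0), 2), round(macro(scores, 1), 2))  # → 83.33 82.86
```

Because each class counts equally, a macro score drops noticeably when the smaller class is misclassified. This happens even when overall accuracy stays high.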
Classify - Infer
https://mapreader.readthedocs.io/en/latest/User-guide/Classify.html
Create dataset with all MNIST data and add to ClassifierContainer
[50]:
# create dataframes (parents and patches) from the MapImages object
parent_df, patch_df = my_files.convert_images()
parent_df.head()
[50]:
image_id | parent_id | image_path | shape
---|---|---|---
49081.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28)
51816.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28)
39566.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28)
24251.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28)
20989.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28)
[51]:
from mapreader import PatchDataset
[52]:
data = PatchDataset(parent_df, transform="val")
[53]:
my_classifier.load_dataset(data, set_name="all_mnist")
Infer labels
[54]:
my_classifier.inference("all_mnist")
[INFO] Each training step will pass: ['all_mnist'].
2023-06-05 10:32:46 599-JY5FK6 [INFO] all_mnist -- 6/6 -- 16/200 ( 8.0% ) --
2023-06-05 10:32:49 599-JY5FK6 [INFO] all_mnist -- 6/6 -- 96/200 ( 48.0% ) --
2023-06-05 10:32:52 599-JY5FK6 [INFO] all_mnist -- 6/6 -- 176/200 ( 88.0% ) --
[INFO] Total time: 0m 7s
[55]:
my_classifier.show_inference_sample_results(label="3", set_name="all_mnist", min_conf=99)
Add predictions to dataframe
[56]:
predictions_df = data.patch_df
[57]:
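import numpy as np

# predicted label names and class indices for every image
predictions_df['predicted_label'] = my_classifier.pred_label
predictions_df['pred'] = my_classifier.pred_label_indices
# pred_conf holds one probability per class; keep the highest (that of the predicted class)
predictions_df['conf'] = np.array(my_classifier.pred_conf).max(axis=1)
predictions_df.head()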
[57]:
image_id | parent_id | image_path | shape | predicted_label | pred | conf
---|---|---|---|---|---|---
49081.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28) | 3 | 0 | 0.999915
51816.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28) | 3 | 0 | 0.999805
39566.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28) | 3 | 0 | 0.999222
24251.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28) | 3 | 0 | 0.999954
20989.png | None | /Users/rwood/LwM/MapReader/worked_examples/non... | (28, 28) | 3 | 0 | 0.999804
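Each row of `pred_conf` holds one probability per class. Taking `max(axis=1)` therefore keeps the probability of the predicted class, and the predicted index maps to a label name through `labels_map`.

A NumPy-free sketch of the same bookkeeping follows, using made-up probabilities. It also adds the kind of confidence filter that `min_conf=99` in `show_inference_sample_results` suggests, assuming `min_conf` is a percentage.

```python
labels_map = {0: "3", 1: "1"}  # same mapping as passed to ClassifierContainer

# Hypothetical per-class probabilities for three images
pred_conf = [[0.9999, 0.0001], [0.2, 0.8], [0.985, 0.015]]

rows = []
for probs in pred_conf:
    conf = max(probs)          # probability of the most likely class
    pred = probs.index(conf)   # index of that class (argmax)
    rows.append({"pred": pred, "predicted_label": labels_map[pred], "conf": conf})

# Keep only confident predictions of label "3" (threshold assumed to be in percent)
min_conf = 99
confident_threes = [r for r in rows
                    if r["predicted_label"] == "3" and 100 * r["conf"] >= min_conf]
print(rows)
print(len(confident_threes))  # → 1
```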
[58]:
predictions_df.to_csv("./predictions_df.csv", sep=",", index_label="image_id")