Purpose
This notebook will guide you through the basic steps to get started with Active Vision.
By the end of this notebook, you will be able to:
- Understand the basic workflow of active learning
- Understand the basic components of Active Vision
- Understand how to use Active Vision to train a model and iteratively improve your dataset
Learner
With the initial dataset prepared, we start by creating an ActiveLearner object and giving this cycle a name.
from active_vision import ActiveLearner
al = ActiveLearner(name="cycle-1")
Load Model
Now let's load a model to use for active learning. Any fastai and timm models are supported.
I'd recommend a model with a small number of parameters, such as resnet18, to keep each active learning cycle fast.
This model is used only within the active learning cycle, to select the most impactful samples to label.
al.load_model(model="resnet18", pretrained=True)
2025-02-04 21:53:06.162 | INFO | active_vision.core:_detect_optimal_device:87 - Apple Silicon GPU detected - will load model on MPS
2025-02-04 21:53:06.162 | INFO | active_vision.core:load_model:70 - Loading a pretrained timm model `resnet18` on `mps`
Load Initial Dataset
Next, let's load the initial labeled dataset we prepared earlier.
import pandas as pd
initial_samples = pd.read_parquet("initial_samples.parquet")
initial_samples
| filepath | label |
---|---|---|
0 | data/imagenette/train/n02102040/n02102040_2788... | English springer |
1 | data/imagenette/train/n02102040/n02102040_3759... | English springer |
2 | data/imagenette/train/n02102040/n02102040_1916... | English springer |
3 | data/imagenette/train/n02102040/n02102040_6147... | English springer |
4 | data/imagenette/train/n02102040/n02102040_403.... | English springer |
... | ... | ... |
95 | data/imagenette/train/n01440764/n01440764_1004... | tench |
96 | data/imagenette/train/n01440764/n01440764_3153... | tench |
97 | data/imagenette/train/n01440764/n01440764_1284... | tench |
98 | data/imagenette/train/n01440764/n01440764_3997... | tench |
99 | data/imagenette/train/n01440764/n01440764_2978... | tench |
100 rows × 2 columns
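The dataframe only needs a file path column and a label column. If you are preparing it yourself from the Imagenette folder layout, where each image sits in a folder named after its ImageNet synset ID, a sketch like the following builds it. build_samples_df is a hypothetical helper, and the synset-to-name mapping is taken from the paths and labels shown in this notebook.

```python
from pathlib import Path

import pandas as pd

# Imagenette synset IDs and the class names used in this notebook
SYNSET_TO_LABEL = {
    "n01440764": "tench",
    "n02102040": "English springer",
    "n02979186": "cassette player",
    "n03000684": "chain saw",
    "n03028079": "church",
    "n03394916": "French horn",
    "n03417042": "garbage truck",
    "n03425413": "gas pump",
    "n03445777": "golf ball",
    "n03888257": "parachute",
}


def build_samples_df(filepaths):
    """Hypothetical helper: one row per image, labeled from its parent folder name."""
    rows = [
        # The parent folder (e.g. n01440764) is the synset ID of the class
        {"filepath": str(fp), "label": SYNSET_TO_LABEL[Path(fp).parent.name]}
        for fp in filepaths
    ]
    return pd.DataFrame(rows, columns=["filepath", "label"])
```

For example, you could collect candidate paths with Path("data/imagenette/train").glob("*/*.JPEG") and then draw a fixed number of images per class (10 here is an arbitrary choice) with df.groupby("label").sample(n=10, random_state=42).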
Now we can load the initial samples into the ActiveLearner object, pointing filepath_col and label_col at the corresponding columns in the dataframe.
al.load_dataset(initial_samples, filepath_col="filepath", label_col="label")
2025-02-04 21:53:06.221 | INFO | active_vision.core:load_dataset:119 - Loading dataset from `filepath` and `label` columns
2025-02-04 21:53:06.257 | INFO | active_vision.core:load_dataset:153 - Creating new learner
2025-02-04 21:53:08.147 | INFO | active_vision.core:_optimize_learner:97 - Enabled mixed precision training
2025-02-04 21:53:08.148 | INFO | active_vision.core:_finalize_setup:105 - Done. Ready to train.
Let's inspect one batch of the loaded dataset.
al.show_batch()
You can inspect the train and validation sets too.
al.train_set
| filepath | label |
---|---|---|
74 | data/imagenette/train/n03445777/n03445777_1058... | golf ball |
53 | data/imagenette/train/n03417042/n03417042_9128... | garbage truck |
29 | data/imagenette/train/n02979186/n02979186_7354... | cassette player |
96 | data/imagenette/train/n01440764/n01440764_3153... | tench |
70 | data/imagenette/train/n03445777/n03445777_4354... | golf ball |
... | ... | ... |
89 | data/imagenette/train/n03888257/n03888257_4345... | parachute |
2 | data/imagenette/train/n02102040/n02102040_1916... | English springer |
35 | data/imagenette/train/n03000684/n03000684_1381... | chain saw |
97 | data/imagenette/train/n01440764/n01440764_1284... | tench |
32 | data/imagenette/train/n03000684/n03000684_8985... | chain saw |
80 rows × 2 columns
al.valid_set
| filepath | label |
---|---|---|
75 | data/imagenette/train/n03445777/ILSVRC2012_val... | golf ball |
15 | data/imagenette/train/n03394916/n03394916_3860... | French horn |
98 | data/imagenette/train/n01440764/n01440764_3997... | tench |
90 | data/imagenette/train/n01440764/n01440764_8805... | tench |
31 | data/imagenette/train/n03000684/n03000684_7905... | chain saw |
65 | data/imagenette/train/n03425413/n03425413_2074... | gas pump |
59 | data/imagenette/train/n03417042/n03417042_79.JPEG | garbage truck |
44 | data/imagenette/train/n03028079/n03028079_8632... | church |
30 | data/imagenette/train/n03000684/n03000684_9935... | chain saw |
18 | data/imagenette/train/n03394916/n03394916_3080... | French horn |
58 | data/imagenette/train/n03417042/ILSVRC2012_val... | garbage truck |
11 | data/imagenette/train/n03394916/n03394916_2128... | French horn |
19 | data/imagenette/train/n03394916/n03394916_3543... | French horn |
79 | data/imagenette/train/n03445777/n03445777_1250... | golf ball |
12 | data/imagenette/train/n03394916/n03394916_5342... | French horn |
16 | data/imagenette/train/n03394916/n03394916_4308... | French horn |
64 | data/imagenette/train/n03425413/n03425413_1007... | gas pump |
77 | data/imagenette/train/n03445777/n03445777_9105... | golf ball |
92 | data/imagenette/train/n01440764/n01440764_1520... | tench |
39 | data/imagenette/train/n03000684/n03000684_2753... | chain saw |
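The split above (80 training and 20 validation images, with shuffled indices) is consistent with holding out a random 20% of the data for validation. Here is a minimal pandas sketch of such a split. random_split, valid_pct, and seed are illustrative names, and active_vision's actual splitting logic is not shown in this notebook.

```python
import pandas as pd


def random_split(df, valid_pct=0.2, seed=None):
    """Hypothetical helper: hold out a random fraction of rows for validation."""
    # df.sample(frac=...) draws round(valid_pct * len(df)) rows without replacement
    valid = df.sample(frac=valid_pct, random_state=seed)
    # Everything not drawn for validation becomes the training set
    train = df.drop(valid.index)
    return train, valid
```

One consequence of such a small validation set is that each of the 20 images is worth 5 percentage points, which is why the validation accuracy in the training logs moves in steps of 0.05.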
Train
Now that we have the initial dataset, we can train the model.
But first, let's run the learning rate finder to get a suggested learning rate for this model.
al.lr_find()
2025-02-04 21:53:08.541 | INFO | active_vision.core:lr_find:194 - Finding optimal learning rate
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
  warnings.warn(
2025-02-04 21:53:16.820 | INFO | active_vision.core:lr_find:196 - Optimal learning rate: 0.0030199517495930195
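A learning rate finder is typically an LR range test: train briefly while the learning rate grows exponentially from one mini-batch to the next, record the loss, and pick a value where the loss is still falling quickly. Below is a minimal NumPy sketch of the two ingredients. The defaults (1e-7 to 10 over 100 iterations) mirror fastai's Learner.lr_find, and the steepest-slope rule is just one common heuristic. Both are assumptions here; this sketch is not how active_vision computes its suggestion and will not necessarily reproduce the 0.0030 value above.

```python
import numpy as np


def lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    # Learning rates spaced evenly on a log scale, one per mini-batch
    return start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))


def steepest_descent_lr(lrs, losses):
    # Slope of the loss with respect to log10(lr); the most negative slope marks
    # the learning rate at which the loss was dropping fastest.
    # Real implementations usually smooth the noisy losses before this step.
    slopes = np.gradient(np.asarray(losses, dtype=float), np.log10(lrs))
    return float(lrs[int(np.argmin(slopes))])
```

In practice, the losses come from actually training on one batch per learning rate. The helper names here are hypothetical.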
Now let's train the model with a learning rate of 5e-3, in the same range as the suggested value. Training runs in two phases: 3 epochs of head tuning, where the pretrained backbone is frozen and only the head is trained, followed by 10 epochs of end-to-end training of the whole model.
al.train(epochs=10, lr=5e-3, head_tuning_epochs=3)
2025-02-04 21:53:16.982 | INFO | active_vision.core:train:207 - Training head for 3 epochs
2025-02-04 21:53:16.983 | INFO | active_vision.core:train:208 - Training model end-to-end for 10 epochs
2025-02-04 21:53:16.983 | INFO | active_vision.core:train:209 - Learning rate: 0.005 with one-cycle learning rate scheduler
Head tuning (3 epochs, backbone frozen):
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 3.548794 | 2.326282 | 0.250000 | 00:00 |
1 | 2.531440 | 0.587719 | 0.850000 | 00:00 |
2 | 1.712109 | 0.455706 | 0.800000 | 00:00 |
End-to-end training (10 epochs):
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.076002 | 0.375373 | 0.850000 | 00:00 |
1 | 0.071136 | 0.379791 | 0.900000 | 00:00 |
2 | 0.070615 | 0.327519 | 0.900000 | 00:00 |
3 | 0.070802 | 0.502693 | 0.800000 | 00:00 |
4 | 0.082460 | 0.544212 | 0.800000 | 00:00 |
5 | 0.075920 | 0.491695 | 0.800000 | 00:00 |
6 | 0.071243 | 0.527742 | 0.800000 | 00:00 |
7 | 0.067941 | 0.484508 | 0.800000 | 00:00 |
8 | 0.069538 | 0.462621 | 0.800000 | 00:00 |
9 | 0.062958 | 0.493862 | 0.800000 | 00:00 |
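The training log mentions a one-cycle learning rate scheduler. A one-cycle schedule warms the learning rate up to its maximum and then anneals it to a very small value. Here is a minimal sketch of such a schedule. The parameters (start at lr_max/25, warm up over the first 25% of steps, end at lr_max/1e5) are fastai's fit_one_cycle defaults and are assumptions here, since active_vision may pass different values. one_cycle_lr is a hypothetical helper.

```python
import math


def one_cycle_lr(step, total_steps, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at a given optimizer step under a cosine one-cycle schedule."""
    warmup = int(total_steps * pct_start)
    lr_start = lr_max / div
    lr_end = lr_max / div_final

    def cos_interp(start, end, pct):
        # Cosine interpolation: equals start at pct=0 and end at pct=1
        return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))

    if step < warmup:
        # Phase 1: warm up from lr_start to lr_max
        return cos_interp(lr_start, lr_max, step / warmup)
    # Phase 2: anneal from lr_max down to lr_end over the remaining steps
    return cos_interp(lr_max, lr_end, (step - warmup) / max(1, total_steps - 1 - warmup))
```

With the settings in this notebook (80 training images at batch size 16, so 5 batches per epoch), the 10 end-to-end epochs correspond to roughly 50 optimizer steps. Evaluating one_cycle_lr(step, 50, 5e-3) for each step traces the warm-up and decay.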
Evaluate
Now that we have a trained model, we can evaluate it on a held-out evaluation set of 3,925 labeled images from the Imagenette validation split.
evaluation_df = pd.read_parquet("evaluation_samples.parquet")
evaluation_df
| filepath | label |
---|---|---|
0 | data/imagenette/val/n03394916/n03394916_32422.... | French horn |
1 | data/imagenette/val/n03394916/n03394916_69132.... | French horn |
2 | data/imagenette/val/n03394916/n03394916_33771.... | French horn |
3 | data/imagenette/val/n03394916/n03394916_29940.... | French horn |
4 | data/imagenette/val/n03394916/ILSVRC2012_val_0... | French horn |
... | ... | ... |
3920 | data/imagenette/val/n02979186/n02979186_27392.... | cassette player |
3921 | data/imagenette/val/n02979186/n02979186_2742.JPEG | cassette player |
3922 | data/imagenette/val/n02979186/n02979186_2312.JPEG | cassette player |
3923 | data/imagenette/val/n02979186/n02979186_12822.... | cassette player |
3924 | data/imagenette/val/n02979186/ILSVRC2012_val_0... | cassette player |
3925 rows × 2 columns
al.evaluate(evaluation_df, filepath_col="filepath", label_col="label")
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
  warnings.warn(
2025-02-04 21:53:40.605 | INFO | active_vision.core:evaluate:285 - Accuracy: 90.24%
0.902420382165605
That is a good start. ~90% accuracy is not bad for a first try with only 100 labeled samples. Let's see if we can improve it.
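If you want to double-check an accuracy figure yourself, for example from a prediction dataframe like the one al.predict returns, you can join predictions to ground truth by file path and compare labels. Here is a minimal pandas sketch. accuracy_from_predictions is a hypothetical helper, and it assumes the predictions have filepath and pred_label columns.

```python
import pandas as pd


def accuracy_from_predictions(pred_df, label_df, filepath_col="filepath", label_col="label"):
    """Hypothetical helper: fraction of labeled rows whose pred_label matches the label."""
    # Align predictions with ground truth by file path (inner join drops unmatched rows)
    merged = label_df.merge(pred_df[[filepath_col, "pred_label"]], on=filepath_col, how="inner")
    # Mean of a boolean Series is the fraction of True values
    return float((merged["pred_label"] == merged[label_col]).mean())
```

As a sanity check on the number above, 3,542 / 3,925 = 0.9024204, which matches the reported value. In other words, the model classified 3,542 of the 3,925 evaluation images correctly.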
Let's save the summary of the cycle.
al.summary()
2025-02-04 21:53:50.757 | INFO | active_vision.core:summary:578 - Saved results to cycle-1_20250204_215350_acc_90.24%_n_100.parquet
| name | accuracy | train_set_size | valid_set_size | dataset_size | num_classes | model | pretrained | loss_fn | device | seed | batch_size | image_size |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | cycle-1 | 0.90242 | 80 | 20 | 100 | 10 | resnet18 | True | FlattenedLoss of CrossEntropyLoss() | mps | None | 16 | 224 |
The call above saves a summary of the cycle to a .parquet file, which is useful for tracking the progress of the active learning process.
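Predict
Using the model, we can predict labels for the unlabeled samples. The resulting predictions (predicted label, confidence, class probabilities, logits, and embeddings) are then passed to the sampling step to pick the most impactful samples to label.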
df = pd.read_parquet("unlabeled_samples.parquet")
filepaths = df["filepath"].tolist()
len(filepaths)
9369
pred_df = al.predict(filepaths, batch_size=128)
pred_df
2025-02-04 21:54:00.351 | INFO | active_vision.core:predict:216 - Running inference on 9369 samples
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
  warnings.warn(
| filepath | pred_label | pred_conf | probs | logits | embeddings |
---|---|---|---|---|---|---|
0 | data/imagenette/train/n03394916/n03394916_4437... | French horn | 0.8815 | [0.0, 0.8815, 0.1002, 0.0007, 0.0056, 0.0015, ... | [-4.0315, 6.3779, 4.2033, -0.7827, 1.3138, 0.0... | [1.5122, 2.9872, -0.0116, -2.0296, -1.1489, -3... |
1 | data/imagenette/train/n03394916/n03394916_4241... | French horn | 0.9981 | [0.0, 0.9981, 0.0002, 0.0, 0.0001, 0.0003, 0.0... | [-4.304, 8.3362, -0.0308, -1.64, -0.8373, 0.17... | [1.8019, -2.0897, -0.1908, -1.2324, 3.1603, -2... |
2 | data/imagenette/train/n03394916/n03394916_3880... | French horn | 0.9605 | [0.0014, 0.9605, 0.0004, 0.0042, 0.0068, 0.017... | [-1.2137, 5.3527, -2.4653, -0.0855, 0.398, 1.3... | [-1.7658, -1.084, 0.1105, 0.6327, -0.3034, -0.... |
3 | data/imagenette/train/n03394916/n03394916_2412... | French horn | 0.9916 | [0.0, 0.9916, 0.0058, 0.0001, 0.0003, 0.0019, ... | [-2.7646, 8.6685, 3.5195, -0.3958, 0.423, 2.42... | [0.5607, -0.0849, -0.0111, -2.5747, 1.7406, -3... |
4 | data/imagenette/train/n03394916/n03394916_1128... | French horn | 0.9345 | [0.0006, 0.9345, 0.0001, 0.0499, 0.0002, 0.009... | [-0.7557, 6.6546, -2.7884, 3.7239, -1.8325, 2.... | [0.0593, -2.8017, -0.8326, 0.0533, 0.3513, -4.... |
... | ... | ... | ... | ... | ... | ... |
9364 | data/imagenette/train/n02979186/n02979186_8089... | cassette player | 0.9998 | [0.0, 0.0, 0.9998, 0.0001, 0.0, 0.0, 0.0, 0.0,... | [0.2201, -0.8615, 11.2814, 2.2068, -2.6583, -2... | [2.7389, -0.1937, 3.834, 3.7718, -1.3812, 3.20... |
9365 | data/imagenette/train/n02979186/n02979186_1944... | cassette player | 0.9978 | [0.0003, 0.0002, 0.9978, 0.0014, 0.0, 0.0003, ... | [1.1453, 0.7587, 9.4202, 2.8624, -1.6734, 1.22... | [0.8961, 0.8141, 3.2657, -1.6927, -0.944, 1.28... |
9366 | data/imagenette/train/n02979186/n02979186_1107... | cassette player | 0.9976 | [0.0001, 0.0, 0.9976, 0.002, 0.0002, 0.0, 0.0,... | [-0.2329, -1.2031, 9.2606, 3.0241, 0.9305, -3.... | [-0.9728, -0.7958, 2.6238, 0.8788, 0.139, 3.13... |
9367 | data/imagenette/train/n02979186/n02979186_2938... | cassette player | 0.9756 | [0.0, 0.0026, 0.9756, 0.0212, 0.0005, 0.0, 0.0... | [-2.3168, 1.9011, 7.8416, 4.0137, 0.2384, -3.5... | [-0.4059, -0.2213, 1.5756, 3.1359, 0.0702, 1.0... |
9368 | data/imagenette/train/n02979186/n02979186_93.JPEG | cassette player | 0.9247 | [0.0, 0.0, 0.9247, 0.0745, 0.0001, 0.0, 0.0, 0... | [-3.7599, -2.0933, 7.9608, 5.4418, -0.9875, -3... | [-1.5722, 0.8393, 0.0824, 0.3304, -1.0368, 1.8... |
9369 rows × 6 columns
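The probs column is the softmax of the logits column, and pred_label and pred_conf are the top class and its probability. Here is a NumPy sketch of that relationship. It assumes the probability vector follows the class names in sorted order. That is consistent with the rows above (French horn is the top entry at index 1 in row 0, and cassette player at index 2 in row 9364), but it is an observation rather than a documented guarantee. The helper names are hypothetical.

```python
import numpy as np

# Assumed class order: the ten class names sorted with Python's sorted()
CLASSES = sorted([
    "tench", "English springer", "cassette player", "chain saw", "church",
    "French horn", "garbage truck", "gas pump", "golf ball", "parachute",
])


def softmax(logits):
    z = np.asarray(logits, dtype=float)
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def summarize_prediction(logits, classes=CLASSES):
    """Return (pred_label, pred_conf, probs) for one row of logits."""
    probs = softmax(logits)
    idx = int(np.argmax(probs))
    return classes[idx], float(probs[idx]), probs
```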
Sample
With the predictions in hand, we can use active learning strategies to select the most impactful samples to label.
For this example, we use the sample_combination strategy to draw 50 samples in total, split across the strategies listed below in the specified proportions: 20 by least confidence, 10 by ratio of confidence, 10 by entropy, 5 model-based outliers, and 5 at random.
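With 50 samples, the proportions divide evenly (50 × 0.4 = 20, 50 × 0.2 = 10, 50 × 0.1 = 5). For totals that do not divide evenly, some rounding rule is needed. The following sketch uses largest-remainder rounding. That rule is an assumption here, and active_vision may round differently. allocate_counts is a hypothetical helper.

```python
def allocate_counts(num_samples, combination):
    """Split num_samples across strategies in proportion to the given weights.

    Assumes the weights sum to 1. Each share is floored first, and any leftover
    slots go to the strategies with the largest fractional remainders.
    """
    raw = {name: num_samples * weight for name, weight in combination.items()}
    counts = {name: int(share) for name, share in raw.items()}
    leftover = num_samples - sum(counts.values())
    # Strategies whose share lost the most to flooring get the spare slots first
    by_remainder = sorted(raw, key=lambda name: raw[name] - counts[name], reverse=True)
    for name in by_remainder[:leftover]:
        counts[name] += 1
    return counts
```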
samples = al.sample_combination(
pred_df,
num_samples=50,
combination={
"least-confidence": 0.4,
"ratio-of-confidence": 0.2,
"entropy": 0.2,
"model-based-outlier": 0.1,
"random": 0.1,
},
)
samples
2025-02-04 21:54:36.646 | INFO | active_vision.core:sample_combination:498 - Using combination sampling to get 50 samples
2025-02-04 21:54:36.648 | INFO | active_vision.core:sample_uncertain:305 - Using least confidence strategy to get top 20 samples
2025-02-04 21:54:36.654 | INFO | active_vision.core:sample_uncertain:328 - Using ratio of confidence strategy to get top 10 samples
2025-02-04 21:54:36.676 | INFO | active_vision.core:sample_uncertain:342 - Using entropy strategy to get top 10 samples
/Users/dnth/Desktop/active-vision/src/active_vision/core.py:345: RuntimeWarning: divide by zero encountered in log2
  df.loc[:, "score"] = df["probs"].apply(lambda x: -np.sum(x * np.log2(x)))
/Users/dnth/Desktop/active-vision/src/active_vision/core.py:345: RuntimeWarning: invalid value encountered in multiply
  df.loc[:, "score"] = df["probs"].apply(lambda x: -np.sum(x * np.log2(x)))
2025-02-04 21:54:36.708 | INFO | active_vision.core:sample_diverse:388 - Using model-based outlier strategy to get top 5 samples
2025-02-04 21:54:36.709 | INFO | active_vision.core:predict:216 - Running inference on 20 samples
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling
  warnings.warn(
/Users/dnth/Desktop/active-vision/.venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py:132: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
  warnings.warn(
2025-02-04 21:54:36.977 | INFO | active_vision.core:sample_random:460 - Sampling 5 random samples
| filepath | strategy | score | pred_label | pred_conf | probs | logits | embeddings |
---|---|---|---|---|---|---|---|---|
0 | data/imagenette/train/n02979186/n02979186_1154... | least-confidence | 0.7868 | cassette player | 0.2132 | [0.0297, 0.0271, 0.2132, 0.086, 0.1707, 0.0837... | [-0.8746, -0.9678, 1.0962, 0.1883, 0.8735, 0.1... | [-1.48, -1.1786, -2.0082, 1.7485, -0.2949, 0.9... |
1 | data/imagenette/train/n01440764/n01440764_1172... | least-confidence | 0.7713 | chain saw | 0.2287 | [0.0467, 0.0442, 0.0105, 0.2287, 0.17, 0.0001,... | [-0.4178, -0.4728, -1.9077, 1.1707, 0.8743, -6... | [-1.3329, -1.4105, 1.748, -2.604, 0.3246, 1.47... |
2 | data/imagenette/train/n03888257/n03888257_2944... | least-confidence | 0.7420 | garbage truck | 0.2580 | [0.0041, 0.1435, 0.0333, 0.2493, 0.0242, 0.258... | [-2.333, 1.2241, -0.2361, 1.7766, -0.5566, 1.8... | [1.4352, 0.9987, -1.3353, -2.8416, -0.0622, -2... |
3 | data/imagenette/train/n01440764/n01440764_1373... | least-confidence | 0.7331 | French horn | 0.2669 | [0.1597, 0.2669, 0.0141, 0.0733, 0.1193, 0.095... | [2.385, 2.8984, -0.0422, 1.6063, 2.0931, 1.865... | [3.2229, 1.6883, 0.6658, 1.351, 1.4092, 0.1549... |
4 | data/imagenette/train/n03425413/n03425413_1511... | least-confidence | 0.7290 | chain saw | 0.2710 | [0.0507, 0.0071, 0.007, 0.271, 0.0164, 0.1051,... | [0.381, -1.5785, -1.5974, 2.0568, -0.7466, 1.1... | [-0.3913, -1.0495, -1.1727, 2.042, -1.0763, -2... |
5 | data/imagenette/train/n03445777/n03445777_258.... | least-confidence | 0.7272 | cassette player | 0.2728 | [0.0754, 0.0773, 0.2728, 0.0022, 0.0041, 0.009... | [1.4243, 1.4483, 2.71, -2.1106, -1.4994, -0.70... | [-2.4045, 1.845, 0.3447, -1.6225, -0.4469, 2.1... |
6 | data/imagenette/train/n03000684/n03000684_3318... | least-confidence | 0.7246 | gas pump | 0.2754 | [0.0127, 0.0766, 0.0056, 0.161, 0.0054, 0.1688... | [-0.8362, 0.9575, -1.6577, 1.701, -1.6913, 1.7... | [-1.6686, -0.5197, -1.2409, 0.5714, 2.7902, -1... |
7 | data/imagenette/train/n03888257/n03888257_2793... | least-confidence | 0.7245 | golf ball | 0.2755 | [0.1317, 0.1, 0.0025, 0.1613, 0.0339, 0.0605, ... | [1.1779, 0.9028, -2.8041, 1.3812, -0.1785, 0.3... | [-2.3253, 1.5071, 1.5362, 0.4531, 2.2692, 1.80... |
8 | data/imagenette/train/n03425413/n03425413_1914... | least-confidence | 0.7244 | chain saw | 0.2756 | [0.0053, 0.0945, 0.1301, 0.2756, 0.0263, 0.0, ... | [-1.9562, 0.9176, 1.237, 1.9876, -0.3615, -6.7... | [-0.9679, 0.9125, 0.2846, -1.3278, 2.8291, 2.3... |
9 | data/imagenette/train/n03000684/n03000684_3003... | least-confidence | 0.7242 | chain saw | 0.2758 | [0.041, 0.0259, 0.16, 0.2758, 0.0104, 0.2444, ... | [0.2453, -0.2131, 1.6076, 2.1521, -1.1224, 2.0... | [-0.2207, 0.0917, 3.5599, 0.8497, -0.5765, -2.... |
10 | data/imagenette/train/n03000684/n03000684_9113... | least-confidence | 0.7239 | gas pump | 0.2761 | [0.002, 0.0062, 0.2506, 0.1104, 0.0025, 0.002,... | [-2.5362, -1.3905, 2.3038, 1.4836, -2.3233, -2... | [0.1616, 1.2885, -3.0482, -1.261, -0.2587, 3.9... |
11 | data/imagenette/train/n03425413/n03425413_1459... | least-confidence | 0.7222 | French horn | 0.2778 | [0.0094, 0.2778, 0.0052, 0.1642, 0.022, 0.0331... | [-1.5272, 1.861, -2.113, 1.3353, -0.6759, -0.2... | [0.5004, 0.6053, -0.17, -1.9617, 1.2721, -1.23... |
12 | data/imagenette/train/n03000684/n03000684_2098... | least-confidence | 0.7217 | garbage truck | 0.2783 | [0.056, 0.2021, 0.0179, 0.2602, 0.0038, 0.2783... | [1.4888, 2.773, 0.3491, 3.0254, -1.1893, 3.093... | [1.7679, 0.9733, 0.7182, 2.2253, -1.1576, -2.8... |
13 | data/imagenette/train/n03425413/n03425413_1438... | least-confidence | 0.7216 | parachute | 0.2784 | [0.0091, 0.2633, 0.0813, 0.2637, 0.0116, 0.000... | [-1.1265, 2.2387, 1.0636, 2.2402, -0.887, -5.8... | [-2.2686, 0.7315, 2.0473, -0.6264, 0.7306, -0.... |
14 | data/imagenette/train/n03425413/n03425413_2219... | least-confidence | 0.7195 | gas pump | 0.2805 | [0.0012, 0.0552, 0.0386, 0.1362, 0.2309, 0.000... | [-3.3461, 0.4593, 0.1, 1.3619, 1.8895, -6.2735... | [0.2687, -1.6976, -0.3161, -0.5363, 1.1781, 1.... |
15 | data/imagenette/train/n03394916/n03394916_2325... | least-confidence | 0.7194 | parachute | 0.2806 | [0.0002, 0.1631, 0.0047, 0.2694, 0.0122, 0.000... | [-3.7493, 2.8131, -0.7436, 3.3148, 0.2187, -2.... | [-1.3543, 0.2244, -1.4944, -3.7606, 0.3826, -0... |
16 | data/imagenette/train/n03417042/n03417042_5698... | least-confidence | 0.7193 | garbage truck | 0.2807 | [0.0349, 0.1007, 0.064, 0.1788, 0.1128, 0.2807... | [-0.6681, 0.392, -0.0617, 0.9662, 0.5057, 1.41... | [-0.4051, -0.7397, -3.2106, 1.1678, 3.0157, 3.... |
17 | data/imagenette/train/n03000684/n03000684_1503... | least-confidence | 0.7158 | chain saw | 0.2842 | [0.0193, 0.0178, 0.003, 0.2842, 0.0075, 0.2199... | [-0.5842, -0.6629, -2.4605, 2.1054, -1.5322, 1... | [-1.0988, 4.61, 2.2829, 1.96, 2.477, -0.6304, ... |
18 | data/imagenette/train/n03425413/n03425413_3144... | least-confidence | 0.7134 | garbage truck | 0.2866 | [0.0096, 0.0519, 0.0721, 0.1285, 0.1201, 0.286... | [-1.5039, 0.1831, 0.512, 1.09, 1.0222, 1.8924,... | [1.897, -1.5106, -3.1823, 0.8091, 1.1605, -3.1... |
19 | data/imagenette/train/n02102040/n02102040_5983... | least-confidence | 0.7128 | English springer | 0.2872 | [0.2872, 0.0889, 0.011, 0.1836, 0.1843, 0.0035... | [2.7052, 1.5325, -0.554, 2.2576, 2.2614, -1.69... | [0.3121, 0.3837, -0.495, -0.3409, 0.8247, -0.9... |
20 | data/imagenette/train/n03888257/n03888257_5543... | ratio-of-confidence | 0.9963 | church | 0.4892 | [0.0006, 0.0034, 0.0018, 0.0057, 0.4892, 0.004... | [-2.3001, -0.4814, -1.1037, 0.0238, 4.4839, -0... | [1.0376, -0.5689, -0.7547, -0.6402, 3.0029, -1... |
21 | data/imagenette/train/n03417042/n03417042_6903... | ratio-of-confidence | 0.9956 | garbage truck | 0.4954 | [0.0026, 0.0005, 0.0002, 0.0011, 0.4932, 0.495... | [0.0406, -1.6485, -2.7978, -0.864, 5.2682, 5.2... | [3.218, -0.5888, -2.0572, 2.9177, 0.4033, -1.4... |
22 | data/imagenette/train/n03425413/n03425413_1268... | ratio-of-confidence | 0.9955 | gas pump | 0.3320 | [0.0004, 0.1871, 0.1191, 0.3305, 0.0264, 0.000... | [-2.9714, 3.0666, 2.6146, 3.6355, 1.1079, -3.5... | [1.5266, 0.3247, 1.064, -0.6718, 1.9302, 1.412... |
23 | data/imagenette/train/n03417042/n03417042_2684... | ratio-of-confidence | 0.9955 | garbage truck | 0.4902 | [0.0066, 0.0008, 0.0057, 0.007, 0.0004, 0.4902... | [0.3548, -1.7464, 0.2096, 0.4099, -2.4527, 4.6... | [3.4206, -0.3718, -0.3964, 1.7022, 0.4602, -0.... |
24 | data/imagenette/train/n03888257/n03888257_7430... | ratio-of-confidence | 0.9952 | golf ball | 0.5010 | [0.0, 0.0, 0.0, 0.0001, 0.0001, 0.0, 0.0, 0.50... | [-1.9501, -2.6581, -1.7523, -0.7569, -1.1854, ... | [-1.6041, -2.6862, -0.4968, -3.7274, -0.3238, ... |
25 | data/imagenette/train/n03394916/n03394916_2900... | ratio-of-confidence | 0.9941 | French horn | 0.3721 | [0.0007, 0.3721, 0.0, 0.1385, 0.0003, 0.0087, ... | [-2.4833, 3.7691, -5.3913, 2.7805, -3.4254, 0.... | [1.8665, -2.0128, 2.1027, -3.7416, -1.5193, 0.... |
26 | data/imagenette/train/n03425413/n03425413_1731... | ratio-of-confidence | 0.9925 | church | 0.4156 | [0.0166, 0.0078, 0.0151, 0.0237, 0.4156, 0.071... | [-0.2679, -1.0225, -0.366, 0.0846, 2.9498, 1.1... | [-0.3695, -0.1422, -3.6139, 0.3352, -0.2954, 1... |
27 | data/imagenette/train/n02979186/n02979186_2351... | ratio-of-confidence | 0.9908 | gas pump | 0.3271 | [0.0007, 0.0483, 0.22, 0.0028, 0.0114, 0.0025,... | [-3.4812, 0.807, 2.3239, -2.039, -0.6325, -2.1... | [0.3288, 0.9046, -2.2022, -1.245, 1.349, 0.409... |
28 | data/imagenette/train/n01440764/n01440764_1696... | ratio-of-confidence | 0.9904 | tench | 0.3759 | [0.0009, 0.0093, 0.0014, 0.1478, 0.0247, 0.000... | [-2.3892, -0.0999, -1.984, 2.6632, 0.8738, -4.... | [-2.2977, -1.5046, 0.6189, -3.0594, 0.6205, 0.... |
29 | data/imagenette/train/n03394916/n03394916_3011... | ratio-of-confidence | 0.9883 | garbage truck | 0.3082 | [0.0143, 0.1777, 0.0004, 0.3046, 0.0971, 0.308... | [-0.5603, 1.9565, -4.1248, 2.4954, 1.3523, 2.5... | [1.3563, -2.4635, -2.8524, -2.1299, 0.3097, -3... |
30 | data/imagenette/train/n03000684/n03000684_5368... | entropy | 0.8433 | parachute | 0.2922 | [0.0375, 0.0399, 0.0015, 0.1428, 0.1145, 0.118... | [-0.1324, -0.0709, -3.3373, 1.204, 0.9836, 1.0... | [-0.4185, 1.8171, 0.0212, -0.7361, -1.7408, -0... |
31 | data/imagenette/train/n03000684/n03000684_1599... | entropy | 0.8233 | garbage truck | 0.3228 | [0.0605, 0.0334, 0.0598, 0.1211, 0.0053, 0.322... | [0.7622, 0.169, 0.751, 1.4561, -1.6818, 2.4367... | [1.3442, 2.557, 0.2429, -1.9658, -0.7566, -1.4... |
32 | data/imagenette/train/n03000684/n03000684_3103... | entropy | 0.8129 | golf ball | 0.3144 | [0.0239, 0.0719, 0.0997, 0.2182, 0.011, 0.1636... | [-0.2995, 0.8013, 1.1277, 1.9113, -1.0769, 1.6... | [-0.7188, 0.7749, -4.6369, -3.064, 1.9052, -1.... |
33 | data/imagenette/train/n03425413/n03425413_9079... | entropy | 0.7864 | cassette player | 0.2892 | [0.0203, 0.034, 0.2892, 0.0837, 0.1193, 0.0007... | [-1.071, -0.5536, 1.5877, 0.3482, 0.7027, -4.3... | [-3.1902, 0.5344, -1.028, 0.092, 1.9895, 1.347... |
34 | data/imagenette/train/n03028079/n03028079_3835... | entropy | 0.7857 | parachute | 0.3484 | [0.061, 0.0576, 0.0105, 0.0159, 0.2247, 0.0323... | [0.0473, -0.0099, -1.715, -1.3006, 1.3504, -0.... | [-2.0185, -1.2173, -1.2625, -1.6554, -0.2798, ... |
35 | data/imagenette/train/n03000684/ILSVRC2012_val... | entropy | 0.7793 | chain saw | 0.3289 | [0.1086, 0.0298, 0.0004, 0.3289, 0.046, 0.0109... | [1.4685, 0.1735, -4.1946, 2.5763, 0.6098, -0.8... | [-0.822, 2.4415, 0.4211, -0.0497, -2.3272, 2.2... |
36 | data/imagenette/train/n03888257/n03888257_1205... | entropy | 0.7696 | tench | 0.3003 | [0.1932, 0.0357, 0.0008, 0.1847, 0.0814, 0.010... | [2.0675, 0.3785, -3.445, 2.0224, 1.2034, -0.83... | [0.1323, -0.9154, 0.3553, -0.6168, 0.502, 0.45... |
37 | data/imagenette/train/n01440764/n01440764_8589... | entropy | 0.7687 | tench | 0.2955 | [0.1251, 0.147, 0.0012, 0.066, 0.0293, 0.0005,... | [1.1197, 1.2809, -3.5146, 0.4796, -0.3313, -4.... | [-5.7497, 0.9306, -0.2185, -1.5692, 2.0429, 0.... |
38 | data/imagenette/train/n03888257/n03888257_2762... | entropy | 0.7670 | parachute | 0.3022 | [0.037, 0.1207, 0.0027, 0.0345, 0.0034, 0.2532... | [-0.3414, 0.8412, -2.9466, -0.4109, -2.7253, 1... | [-0.0619, 0.3258, 0.089, -0.0403, -1.6754, -2.... |
39 | data/imagenette/train/n03425413/n03425413_1654... | entropy | 0.7662 | gas pump | 0.3035 | [0.0233, 0.1564, 0.0434, 0.1933, 0.0093, 0.210... | [-0.4187, 1.4847, 0.2032, 1.6968, -1.335, 1.78... | [1.0856, -0.4017, -1.1049, 1.7158, 0.2246, -0.... |
40 | data/imagenette/train/n02102040/n02102040_2595... | model-based-outlier | 1.0000 | English springer | 0.9786 | [0.9786, 0.0121, 0.0003, 0.0036, 0.0007, 0.001... | [7.2205, 2.8313, -0.9539, 1.6166, 0.0282, 0.84... | [1.305, -2.4159, 0.4358, -1.7312, -0.004, 0.38... |
41 | data/imagenette/train/n03445777/n03445777_1386... | model-based-outlier | 1.0000 | golf ball | 0.5155 | [0.0034, 0.0257, 0.0205, 0.1739, 0.0073, 0.045... | [-1.3996, 0.6301, 0.402, 2.5419, -0.6247, 1.20... | [0.9203, -0.6261, 0.7931, -0.8476, 0.4564, 0.8... |
42 | data/imagenette/train/n02102040/n02102040_539.... | model-based-outlier | 1.0000 | English springer | 0.9712 | [0.9712, 0.015, 0.0005, 0.0051, 0.0007, 0.0003... | [6.8735, 2.703, -0.7499, 1.6197, -0.4159, -1.2... | [-3.3191, -1.2588, 6.1337, -3.5114, 2.8639, -0... |
43 | data/imagenette/train/n03000684/n03000684_1637... | model-based-outlier | 1.0000 | parachute | 0.8370 | [0.003, 0.012, 0.0008, 0.1107, 0.0142, 0.0131,... | [-0.8472, 0.5192, -2.1964, 2.7446, 0.694, 0.61... | [-2.5737, 1.1527, -1.0046, -0.6686, -0.7702, -... |
44 | data/imagenette/train/n03000684/n03000684_3268... | model-based-outlier | 1.0000 | chain saw | 0.3547 | [0.0345, 0.063, 0.0021, 0.3547, 0.0815, 0.2783... | [0.2586, 0.8616, -2.5209, 2.5899, 1.1198, 2.34... | [-0.3133, -1.8713, -0.4083, -0.5728, -2.0835, ... |
45 | data/imagenette/train/n03000684/n03000684_1471... | random | 0.0000 | chain saw | 0.7763 | [0.045, 0.0082, 0.0005, 0.7763, 0.0134, 0.0777... | [1.0677, -0.6302, -3.3938, 3.9151, -0.1408, 1.... | [1.6981, 0.2913, 2.1714, 2.7012, -1.3052, 0.22... |
46 | data/imagenette/train/n02102040/n02102040_3767... | random | 0.0000 | English springer | 0.9997 | [0.9997, 0.0, 0.0001, 0.0, 0.0001, 0.0, 0.0, 0... | [11.4268, 1.1271, 2.3459, 1.3793, 2.5168, -2.5... | [2.1301, -1.6942, 1.1031, 0.1852, 1.049, -0.23... |
47 | data/imagenette/train/n03425413/n03425413_1957... | random | 0.0000 | gas pump | 0.9993 | [0.0, 0.0002, 0.0001, 0.0002, 0.0001, 0.0, 0.9... | [-3.6787, 0.4349, -0.4148, 0.4399, 0.1826, -3.... | [3.4027, 1.9421, -1.5728, 2.0621, -0.5079, -0.... |
48 | data/imagenette/train/n03028079/n03028079_2035... | random | 0.0000 | church | 0.9973 | [0.0001, 0.001, 0.0001, 0.0, 0.9973, 0.0004, 0... | [-1.6015, 0.9666, -0.9836, -3.0406, 7.8966, 0.... | [0.1738, -0.9133, -1.6448, -0.2947, 2.0651, -1... |
49 | data/imagenette/train/n03028079/n03028079_2268... | random | 0.0000 | church | 0.9999 | [0.0, 0.0, 0.0, 0.0, 0.9999, 0.0, 0.0, 0.0, 0.... | [-1.1582, 0.0911, -1.2376, -2.258, 10.9708, -1... | [3.4548, -0.7209, -4.7961, 1.7589, 2.2121, -2.... |
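The uncertainty scores in the table can be related back to the probabilities. The least-confidence scores equal 1 - pred_conf (row 0: 1 - 0.2132 = 0.7868). The ratio-of-confidence scores look like the second-highest probability divided by the highest (row 21: 0.4932 / 0.4954 ≈ 0.9956). The entropy scores fall between 0 and 1, which suggests normalization, for example by log2 of the number of classes. That normalization is an assumption in the sketch below. The sketch also computes entropy so that zero probabilities contribute 0 instead of triggering the divide-by-zero and NaN warnings seen in the log above. The function names are illustrative, not active_vision's API.

```python
import numpy as np


def least_confidence(probs):
    # 1 minus the highest class probability: higher means less confident
    return 1.0 - np.max(probs, axis=-1)


def ratio_of_confidence(probs):
    # Second-highest probability divided by the highest: close to 1 when the
    # top two classes are nearly tied
    top2 = np.sort(probs, axis=-1)[..., -2:]
    return top2[..., 0] / top2[..., 1]


def normalized_entropy(probs):
    # Shannon entropy in bits divided by log2(num_classes), so it lies in [0, 1].
    # log2 is only evaluated where p > 0; elsewhere log_p stays 0, so the
    # 0 * log2(0) terms contribute 0 instead of NaN.
    probs = np.asarray(probs, dtype=float)
    log_p = np.log2(probs, out=np.zeros_like(probs), where=probs > 0)
    return -np.sum(probs * log_p, axis=-1) / np.log2(probs.shape[-1])
```

Because all three functions reduce over the last axis, they work on a single probability vector or on a 2-D array with one row per image.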
Label
Let's label the samples and save them to a parquet file named combination.parquet. This launches a Gradio interface for labeling the samples.
al.label(samples, output_filename="combination.parquet")
Once you are done labeling, you can inspect the samples you just labeled.
labeled_df = pd.read_parquet("combination.parquet")
labeled_df
| filepath | label |
---|---|---|
0 | data/imagenette/train/n03888257/n03888257_3881... | parachute |
1 | data/imagenette/train/n03417042/n03417042_7047... | garbage truck |
2 | data/imagenette/train/n03028079/n03028079_5956... | church |
3 | data/imagenette/train/n03000684/n03000684_1815... | chain saw |
4 | data/imagenette/train/n01440764/n01440764_1507... | tench |
5 | data/imagenette/train/n03888257/n03888257_2922... | parachute |
6 | data/imagenette/train/n03445777/n03445777_4929... | golf ball |
7 | data/imagenette/train/n02979186/n02979186_2785... | cassette player |
8 | data/imagenette/train/n03888257/n03888257_4289... | parachute |
9 | data/imagenette/train/n03888257/n03888257_2103... | parachute |
10 | data/imagenette/train/n03000684/n03000684_5543... | chain saw |
11 | data/imagenette/train/n03000684/n03000684_1278... | chain saw |
12 | data/imagenette/train/n03888257/n03888257_6468... | parachute |
13 | data/imagenette/train/n03445777/n03445777_1063... | golf ball |
14 | data/imagenette/train/n03445777/n03445777_1641... | golf ball |
15 | data/imagenette/train/n03425413/n03425413_2094... | gas pump |
16 | data/imagenette/train/n02979186/n02979186_2764... | cassette player |
17 | data/imagenette/train/n03394916/n03394916_3318... | French horn |
18 | data/imagenette/train/n03888257/n03888257_7115... | parachute |
19 | data/imagenette/train/n03417042/n03417042_1793... | garbage truck |
20 | data/imagenette/train/n03028079/n03028079_1363... | church |
21 | data/imagenette/train/n02979186/n02979186_4087... | cassette player |
22 | data/imagenette/train/n03417042/n03417042_2143... | garbage truck |
23 | data/imagenette/train/n03000684/n03000684_8799... | chain saw |
24 | data/imagenette/train/n03394916/n03394916_3595... | French horn |
25 | data/imagenette/train/n03417042/n03417042_2594... | garbage truck |
26 | data/imagenette/train/n03425413/n03425413_1124... | gas pump |
27 | data/imagenette/train/n03000684/n03000684_1975... | chain saw |
28 | data/imagenette/train/n02102040/n02102040_5983... | English springer |
29 | data/imagenette/train/n03000684/n03000684_1609... | chain saw |
30 | data/imagenette/train/n03445777/n03445777_2967... | parachute |
31 | data/imagenette/train/n03445777/n03445777_258.... | golf ball |
32 | data/imagenette/train/n02102040/n02102040_1444... | English springer |
33 | data/imagenette/train/n03888257/n03888257_1929... | parachute |
34 | data/imagenette/train/n03888257/n03888257_1770... | parachute |
35 | data/imagenette/train/n03000684/n03000684_9664... | chain saw |
36 | data/imagenette/train/n03417042/n03417042_2236... | garbage truck |
37 | data/imagenette/train/n02102040/n02102040_155.... | English springer |
38 | data/imagenette/train/n03445777/n03445777_9976... | golf ball |
39 | data/imagenette/train/n02979186/n02979186_205.... | cassette player |
40 | data/imagenette/train/n03028079/n03028079_5551... | church |
41 | data/imagenette/train/n02979186/n02979186_966.... | cassette player |
42 | data/imagenette/train/n01440764/n01440764_2043... | tench |
43 | data/imagenette/train/n03417042/n03417042_1869... | garbage truck |
44 | data/imagenette/train/n02102040/n02102040_6763... | English springer |
45 | data/imagenette/train/n01440764/n01440764_1455... | tench |
46 | data/imagenette/train/n03028079/n03028079_2489... | church |
47 | data/imagenette/train/n03425413/n03425413_2110... | gas pump |
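Because this demo uses Imagenette, where each file's parent folder encodes its class, you can cross-check the labels you just entered. For example, row 30 sits in the n03445777 (golf ball) folder but is labeled parachute. This check only works for demo data with known ground truth. flag_label_mismatches is a hypothetical helper, and the mapping is the one implied by the paths and labels in this notebook.

```python
from pathlib import Path

import pandas as pd

# Imagenette synset IDs and the class names used in this notebook
SYNSET_TO_LABEL = {
    "n01440764": "tench",
    "n02102040": "English springer",
    "n02979186": "cassette player",
    "n03000684": "chain saw",
    "n03028079": "church",
    "n03394916": "French horn",
    "n03417042": "garbage truck",
    "n03425413": "gas pump",
    "n03445777": "golf ball",
    "n03888257": "parachute",
}


def flag_label_mismatches(labeled_df):
    """Hypothetical helper: rows whose label disagrees with the folder's synset."""
    synsets = labeled_df["filepath"].map(lambda fp: Path(fp).parent.name)
    # Unknown synsets map to NaN, which never equals a label, so they are flagged too
    expected = synsets.map(SYNSET_TO_LABEL)
    return labeled_df[expected != labeled_df["label"]]
```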
Add to dataset
Now that we have labeled the samples, we can add them to the existing dataset and save the result to a parquet file named active_labeled.parquet.
al.add_to_dataset(labeled_df, output_filename="active_labeled.parquet")
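Conceptually, this step appends the newly labeled rows to the existing labeled data. Here is a minimal pandas sketch of that idea. It assumes that if a file appears in both frames, the newer label should win. That is a design choice made for this sketch, not necessarily what add_to_dataset does. merge_labeled is a hypothetical helper.

```python
import pandas as pd


def merge_labeled(existing_df, new_df, filepath_col="filepath"):
    """Hypothetical helper: combine old and new labels, one row per file."""
    combined = pd.concat([existing_df, new_df], ignore_index=True)
    # keep="last" keeps the label from new_df when a file appears in both frames
    return combined.drop_duplicates(subset=filepath_col, keep="last").reset_index(drop=True)
```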
Repeat
Congratulations! You have completed the first cycle of active learning. We now have a slightly larger labeled dataset and a trained model. What's left is to repeat the cycle of training, predicting, sampling, labeling, and adding to the dataset until the model is good enough.
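Between cycles, the pool of unlabeled images should shrink by the images you just labeled, so the next round of predictions does not resample them. Here is a minimal pandas sketch, assuming the pool is kept as a dataframe with a filepath column like unlabeled_samples.parquet. remaining_pool is a hypothetical helper.

```python
import pandas as pd


def remaining_pool(unlabeled_df, labeled_df, filepath_col="filepath"):
    """Hypothetical helper: drop already-labeled files from the unlabeled pool."""
    mask = ~unlabeled_df[filepath_col].isin(labeled_df[filepath_col])
    return unlabeled_df[mask].reset_index(drop=True)
```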
Tracking Progress
You can track the progress of the active learning process by inspecting the .parquet files saved each time you call al.summary().
In an earlier run of this example (dated 2025-02-01, so the cycle-1 numbers differ slightly from the run above), I ran 4 cycles of active learning, which produced the following files:
cycle-1_20250201_224658_acc_88.69%_n_100.parquet
cycle-2_20250201_225913_acc_92.18%_n_149.parquet
cycle-3_20250201_232340_acc_93.45%_n_195.parquet
cycle-4_20250201_233055_acc_94.52%_n_243.parquet
Each filename encodes the cycle name, the date and time of the run, the accuracy, and the number of labeled samples.
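Because the naming is regular, you can recover these fields from the filename itself. Here is a sketch using Python's re module. The pattern is inferred from the four filenames above rather than taken from the library, so treat it as an assumption. parse_cycle_filename is a hypothetical helper.

```python
import re
from datetime import datetime

# Pattern inferred from filenames like cycle-2_20250201_225913_acc_92.18%_n_149.parquet
CYCLE_FILE_RE = re.compile(
    r"^(?P<name>.+?)_(?P<date>\d{8})_(?P<time>\d{6})"
    r"_acc_(?P<acc>\d+(?:\.\d+)?)%_n_(?P<n>\d+)\.parquet$"
)


def parse_cycle_filename(filename):
    """Hypothetical helper: split a cycle summary filename into its fields."""
    m = CYCLE_FILE_RE.match(filename)
    if m is None:
        raise ValueError(f"Unexpected filename: {filename}")
    return {
        "name": m["name"],
        "timestamp": datetime.strptime(m["date"] + m["time"], "%Y%m%d%H%M%S"),
        "accuracy": float(m["acc"]) / 100,  # percent -> fraction
        "n": int(m["n"]),
    }
```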
import glob
import pandas as pd
# Get all cycle summary files saved by al.summary()
cycle_files = glob.glob("cycle-*.parquet")
# Read and concatenate all cycle files
all_cycles_df = pd.concat([pd.read_parquet(f) for f in cycle_files], ignore_index=True)
# Sort by the numeric cycle index so that cycle-10 comes after cycle-9 (a plain string sort would not)
all_cycles_df = all_cycles_df.sort_values(
    by="name", key=lambda s: s.str.extract(r"(\d+)$")[0].astype(int)
)
all_cycles_df
| name | accuracy | train_set_size | valid_set_size | dataset_size | num_classes | model | pretrained | loss_fn | device | seed | batch_size | image_size |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | cycle-1 | 0.886879 | 80 | 20 | 100 | 10 | resnet18 | True | FlattenedLoss of CrossEntropyLoss() | mps | None | 8 | 224 |
0 | cycle-2 | 0.921783 | 120 | 29 | 149 | 10 | resnet18 | True | FlattenedLoss of CrossEntropyLoss() | mps | None | 8 | 224 |
3 | cycle-3 | 0.934522 | 156 | 39 | 195 | 10 | resnet18 | True | FlattenedLoss of CrossEntropyLoss() | mps | None | 8 | 224 |
2 | cycle-4 | 0.945223 | 195 | 48 | 243 | 10 | resnet18 | True | FlattenedLoss of CrossEntropyLoss() | mps | None | 8 | 224 |
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
sns.lineplot(data=all_cycles_df, x='dataset_size', y='accuracy', marker='o')
plt.title('Model Accuracy vs Dataset Size')
plt.xlabel('Number of Images')
plt.ylabel('Accuracy')
plt.gca().yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: '{:.1%}'.format(y)))
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()
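The plot shows accuracy rising with every cycle, but the gain per newly labeled image can be made explicit from the same table. Here is a short pandas sketch. marginal_gain is a hypothetical helper that expects the name, accuracy, and dataset_size columns shown above.

```python
import pandas as pd


def marginal_gain(df):
    """Hypothetical helper: accuracy gained per additional labeled image, cycle to cycle."""
    out = df.sort_values("dataset_size").reset_index(drop=True)
    out["acc_gain"] = out["accuracy"].diff()
    out["gain_per_label"] = out["acc_gain"] / out["dataset_size"].diff()
    return out[["name", "dataset_size", "accuracy", "acc_gain", "gain_per_label"]]
```

Applied to the four cycles above, cycle-2 added 49 images for +3.5 points of accuracy (about 0.07 points per image), while cycle-4 added 48 images for +1.1 points (about 0.02 points per image). Returns diminish as the dataset grows, even though each cycle still helps.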