Convolutional Neural Networks (CNNs) tutorial for CS-772 2025
Author: Jimut Bahan Pal
Install the necessary libraries¶
In [ ]:
! pip install -q gdown
! pip install -q colorama # for color-based texts
! pip install -q torchviz # for visualizing graphs
! pip install -q torchview # for visualizing graphs
! pip install -q graphviz # for visualizing graphs
! pip install -q torchsummary # for finding the number of parameters of a model
In [ ]:
# delete dataset if already present
! rm -rf dataset.zip
Download the dataset from Google Drive¶
In [ ]:
! gdown --id 1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE -O dataset.zip
/usr/local/lib/python3.12/dist-packages/gdown/__main__.py:140: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID. warnings.warn( Downloading... From (original): https://drive.google.com/uc?id=1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE From (redirected): https://drive.google.com/uc?id=1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE&confirm=t&uuid=c5388783-590f-4505-b990-9ace5b78ac10 To: /content/dataset.zip 100% 707M/707M [00:10<00:00, 66.7MB/s]
In [ ]:
! unzip -qq dataset.zip
Check the number of images present in the whole dataset¶
In [ ]:
! ls weed_augmented/images/* | wc
7872 7872 259776
Imports¶
In [ ]:
# General imports
import os
import cv2
import glob
import numpy as np
import random
from tqdm import tqdm
from pathlib import Path
from colorama import Fore, Style
from collections import Counter, defaultdict
# Model based imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# Viz-based imports
from torchsummary import summary
from torchview import draw_graph
import graphviz
graphviz.set_jupyter_format('png')
import seaborn as sns
import matplotlib.pyplot as plt
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
Using device: cuda
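Note: only Python's and NumPy's random generators are seeded above. For fuller reproducibility of weight initialization and data shuffling one would typically also seed PyTorch; a small optional sketch (not part of the original run):
In [ ]:
# Optional: also seed PyTorch (not executed in the original run above).
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)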
This dataset contains three pixel classes; here, we examine the percentage of pixels belonging to each class.¶
In [ ]:
def analyze_3class_masks(mask_folder="./weed_augmented/masks/"):
"""
Fast analysis for 3-class segmentation masks with specific RGB values.
"""
# Define the 3 classes
classes = {
'background': np.array([0, 0, 0]), # Black
'class_1': np.array([0, 128, 0]), # Green
        'class_2': np.array([0, 0, 128])     # Red (the mask is kept in OpenCV's BGR order below)
}
# Initialize counters
counts = {name: 0 for name in classes.keys()}
total_pixels = 0
# Get all mask files
mask_files = glob.glob(f"{mask_folder}*.png")
print(f"Processing {len(mask_files)} mask files...")
# Process each mask
for mask_file in tqdm(mask_files, desc="Analyzing masks"):
# Read mask (BGR format)
mask_rgb = cv2.imread(mask_file)
if mask_rgb is None:
continue
        # No BGR->RGB conversion here: the class colors above are given in BGR order
        # mask_rgb = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
h, w = mask_rgb.shape[:2]
total_pixels += h * w
# Count pixels for each class using vectorized operations
for class_name, class_color in classes.items():
# Create boolean mask for this class
matches = np.all(mask_rgb == class_color, axis=2)
counts[class_name] += np.sum(matches)
# Calculate percentages
print(f"\n{'Class':<12} {'Pixels':<12} {'Percentage'}")
print("-" * 35)
for class_name, pixel_count in counts.items():
percentage = (pixel_count / total_pixels) * 100 if total_pixels > 0 else 0
print(f"{class_name:<12} {pixel_count:<12,} {percentage:.2f}%")
print(f"\nTotal pixels: {total_pixels:,}")
print(f"Files processed: {len(mask_files)}")
return counts, total_pixels
analyze_3class_masks()
Processing 7872 mask files...
Analyzing masks: 100%|██████████| 7872/7872 [03:49<00:00, 34.25it/s]
Class Pixels Percentage ----------------------------------- background 1,662,369,971 80.56% class_1 193,860,104 9.39% class_2 206,718,645 10.02% Total pixels: 2,063,597,568 Files processed: 7872
Out[ ]:
({'background': np.int64(1662369971), 'class_1': np.int64(193860104), 'class_2': np.int64(206718645)}, 2063597568)
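The background class dominates at roughly 80% of all pixels. One optional way to turn such counts into loss weights is inverse-frequency weighting; the sketch below is purely illustrative (the training code later simply hard-codes weights of [0.5, 1.0, 1.0]).
In [ ]:
# Illustrative only: inverse-frequency class weights from the pixel counts printed above.
pixel_counts = np.array([1662369971, 193860104, 206718645], dtype=np.float64)  # background, class_1, class_2
inv_freq = pixel_counts.sum() / pixel_counts
class_weights = inv_freq / inv_freq.mean()    # normalize so the average weight is 1
print(class_weights)                          # strongly up-weights the two foreground classes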
Creating the train-val-test split, with 60% of the dataset for training, 10% for validation and 30% for testing; the images are also resized to 256x256 before being written out. The image-mask pairs are stored in the respective folders, i.e., train_images, train_masks, val_images, val_masks, test_images, and test_masks.¶
In [ ]:
# Configuration
source_folder = 'weed_augmented'
output_folder = '.'
# use a seed of 42 for reproducibility
# random.seed(42)
# Create directories
dirs = ['train_images', 'train_masks', 'val_images', 'val_masks', 'test_images', 'test_masks']
for dir_name in dirs:
os.makedirs(os.path.join(output_folder, dir_name), exist_ok=True)
# Get matching image-mask pairs
images_path = os.path.join(source_folder, 'images')
masks_path = os.path.join(source_folder, 'masks')
image_files = [f for f in os.listdir(images_path) if f.lower().endswith('.jpg')]
pairs = []
for img_file in image_files:
mask_file = img_file.replace('.jpg', '.png')
if os.path.exists(os.path.join(masks_path, mask_file)):
pairs.append((img_file, mask_file))
print(f"Found {len(pairs)} matching image-mask pairs")
# Shuffle and split
random.shuffle(pairs)
n_total = len(pairs)
n_train = int(n_total * 0.6)
n_val = int(n_total * 0.1)
train_pairs = pairs[:n_train]
val_pairs = pairs[n_train:n_train + n_val]
test_pairs = pairs[n_train + n_val:]
print(f"Split: Train={len(train_pairs)}, Val={len(val_pairs)}, Test={len(test_pairs)}")
# Process and resize images
target_size = (256, 256)
for pairs_list, split_name in [(train_pairs, 'train'), (val_pairs, 'val'), (test_pairs, 'test')]:
for img_file, mask_file_name in tqdm(pairs_list):
# Load and resize image
img = cv2.imread(os.path.join(images_path, img_file))
img_resized = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
# Load and resize mask
mask = cv2.imread(os.path.join(masks_path, mask_file_name))
mask_resized = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)
        # Save the resized pair; both are written under the mask's .png filename so that
        # image and mask share the same base name (the image is saved as PNG as well)
        cv2.imwrite(os.path.join(output_folder, f'{split_name}_images', mask_file_name), img_resized)
        cv2.imwrite(os.path.join(output_folder, f'{split_name}_masks', mask_file_name), mask_resized)
Found 7872 matching image-mask pairs Split: Train=4723, Val=787, Test=2362
100%|██████████| 4723/4723 [00:49<00:00, 94.84it/s] 100%|██████████| 787/787 [00:08<00:00, 95.92it/s] 100%|██████████| 2362/2362 [00:26<00:00, 89.98it/s]
Creating overlays of the masks on the images, and showing the first 5 samples from each of the train, validation and test splits¶
In [ ]:
# Analyze mask colors
mask_folder = os.path.join(output_folder, 'train_masks')
mask_files = [f for f in os.listdir(mask_folder) if f.lower().endswith('.png')]
sample_files = random.sample(mask_files, min(20, len(mask_files)))
all_colors = []
for filename in sample_files:
mask = cv2.imread(os.path.join(mask_folder, filename))
mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
unique_colors = np.unique(mask_rgb.reshape(-1, 3), axis=0)
all_colors.extend([tuple(color) for color in unique_colors])
color_counts = Counter(all_colors)
class_colors = [color for color, _ in color_counts.most_common(10)]
print("Most common mask colors (RGB):")
for color in class_colors[:5]:
print(f" {color}")
# Visualize samples
def create_overlay(image, mask, colors, alpha=0.6):
overlay = image.copy()
mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
for color in colors:
if color == (0, 0, 0):
continue
color_mask = np.all(mask_rgb == color, axis=2)
if np.any(color_mask):
overlay[color_mask] = (alpha * np.array(color) + (1 - alpha) * overlay[color_mask]).astype(np.uint8)
return overlay
# Show samples for each split
for split in ['train', 'val', 'test']:
img_folder = os.path.join(output_folder, f'{split}_images')
mask_folder = os.path.join(output_folder, f'{split}_masks')
img_files = sorted([f for f in os.listdir(img_folder) if f.lower().endswith('.png')])[:5]
if not img_files:
continue
fig, axes = plt.subplots(3, len(img_files), figsize=(10, 8))
fig.suptitle(f'{split.capitalize()} Set - First {len(img_files)} Samples', fontsize=16)
for i, img_file in enumerate(img_files):
mask_file = img_file
img = cv2.imread(os.path.join(img_folder, img_file))
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.imread(os.path.join(mask_folder, mask_file))
mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
overlay = create_overlay(img_rgb, mask, class_colors)
axes[0, i].imshow(img_rgb)
axes[0, i].set_title(f'Image {i+1}', fontsize=10)
axes[0, i].axis('off')
axes[1, i].imshow(mask_rgb)
axes[1, i].set_title(f'Mask {i+1}', fontsize=10)
axes[1, i].axis('off')
axes[2, i].imshow(overlay)
axes[2, i].set_title(f'Overlay {i+1}', fontsize=10)
axes[2, i].axis('off')
plt.tight_layout()
plt.show()
print(f"Dataset organized in '{output_folder}' folder")
Most common mask colors (RGB): (np.uint8(0), np.uint8(0), np.uint8(0)) (np.uint8(128), np.uint8(0), np.uint8(0)) (np.uint8(0), np.uint8(128), np.uint8(0)) (np.uint8(1), np.uint8(0), np.uint8(0)) (np.uint8(2), np.uint8(0), np.uint8(0))
Dataset organized in '.' folder
In [ ]:
# https://github.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/blob/master/resunet%2B%2B_pytorch.py
# Note: This uses 3-channel input (in standard channel-first mode)
class Squeeze_Excitation(nn.Module):
def __init__(self, channel, r=8):
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.net = nn.Sequential(
nn.Linear(channel, channel // r, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // r, channel, bias=False),
nn.Sigmoid(),
)
def forward(self, inputs):
b, c, _, _ = inputs.shape
x = self.pool(inputs).view(b, c)
x = self.net(x).view(b, c, 1, 1)
x = inputs * x
return x
class Stem_Block(nn.Module):
def __init__(self, in_c, out_c, stride):
super().__init__()
self.c1 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU(),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
)
self.c2 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
nn.BatchNorm2d(out_c),
)
self.attn = Squeeze_Excitation(out_c)
def forward(self, inputs):
x = self.c1(inputs)
s = self.c2(inputs)
y = self.attn(x + s)
return y
class ResNet_Block(nn.Module):
def __init__(self, in_c, out_c, stride):
super().__init__()
self.c1 = nn.Sequential(
nn.BatchNorm2d(in_c),
nn.ReLU(),
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride),
nn.BatchNorm2d(out_c),
nn.ReLU(),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
)
self.c2 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
nn.BatchNorm2d(out_c),
)
self.attn = Squeeze_Excitation(out_c)
def forward(self, inputs):
x = self.c1(inputs)
s = self.c2(inputs)
y = self.attn(x + s)
return y
class ASPP(nn.Module):
def __init__(self, in_c, out_c, rate=[1, 6, 12, 18]):
super().__init__()
self.c1 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[0], padding=rate[0]),
nn.BatchNorm2d(out_c)
)
self.c2 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[1], padding=rate[1]),
nn.BatchNorm2d(out_c)
)
self.c3 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[2], padding=rate[2]),
nn.BatchNorm2d(out_c)
)
self.c4 = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[3], padding=rate[3]),
nn.BatchNorm2d(out_c)
)
self.c5 = nn.Conv2d(out_c, out_c, kernel_size=1, padding=0)
def forward(self, inputs):
x1 = self.c1(inputs)
x2 = self.c2(inputs)
x3 = self.c3(inputs)
x4 = self.c4(inputs)
x = x1 + x2 + x3 + x4
y = self.c5(x)
return y
class Attention_Block(nn.Module):
def __init__(self, in_c):
super().__init__()
out_c = in_c[1]
self.g_conv = nn.Sequential(
nn.BatchNorm2d(in_c[0]),
nn.ReLU(),
nn.Conv2d(in_c[0], out_c, kernel_size=3, padding=1),
nn.MaxPool2d((2, 2))
)
self.x_conv = nn.Sequential(
nn.BatchNorm2d(in_c[1]),
nn.ReLU(),
nn.Conv2d(in_c[1], out_c, kernel_size=3, padding=1),
)
self.gc_conv = nn.Sequential(
nn.BatchNorm2d(in_c[1]),
nn.ReLU(),
nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
)
def forward(self, g, x):
g_pool = self.g_conv(g)
x_conv = self.x_conv(x)
gc_sum = g_pool + x_conv
gc_conv = self.gc_conv(gc_sum)
y = gc_conv * x
return y
class Decoder_Block(nn.Module):
def __init__(self, in_c, out_c):
super().__init__()
self.a1 = Attention_Block(in_c)
self.up = nn.Upsample(scale_factor=2, mode="nearest")
self.r1 = ResNet_Block(in_c[0]+in_c[1], out_c, stride=1)
def forward(self, g, x):
d = self.a1(g, x)
d = self.up(d)
d = torch.cat([d, g], axis=1)
d = self.r1(d)
return d
class ResUnetPlusPlus(nn.Module):
def __init__(self, input_ch, n_class):
super().__init__()
self.num_classes = n_class
self.c1 = Stem_Block(input_ch, 16, stride=1)
self.c2 = ResNet_Block(16, 32, stride=2)
self.c3 = ResNet_Block(32, 64, stride=2)
self.c4 = ResNet_Block(64, 128, stride=2)
self.b1 = ASPP(128, 256)
self.d1 = Decoder_Block([64, 256], 128)
self.d2 = Decoder_Block([32, 128], 64)
self.d3 = Decoder_Block([16, 64], 32)
self.aspp = ASPP(32, 16)
self.output = nn.Conv2d(16, n_class, kernel_size=1, padding=0)
def forward(self, inputs):
c1 = self.c1(inputs)
c2 = self.c2(c1)
c3 = self.c3(c2)
c4 = self.c4(c3)
b1 = self.b1(c4)
d1 = self.d1(c3, b1)
d2 = self.d2(c2, d1)
d3 = self.d3(c1, d2)
output = self.aspp(d3)
output = self.output(output)
        # Note: the activation is applied inside the model, so for multi-class the
        # network returns softmax probabilities rather than raw logits
        if self.num_classes == 1:
            output = torch.sigmoid(output)
        elif self.num_classes > 1:
            output = F.softmax(output, dim=1)
return output
In [ ]:
# check the model structure
#CALLING THE MODEL
model = ResUnetPlusPlus(input_ch=3, n_class=3)
# model = nn.DataParallel(model)
model = model.to(device)
print(model)
ResUnetPlusPlus( (c1): Stem_Block( (c1): Sequential( (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): ReLU() (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=16, out_features=2, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=2, out_features=16, bias=False) (3): Sigmoid() ) ) ) (c2): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2)) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=32, out_features=4, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=4, out_features=32, bias=False) (3): Sigmoid() ) ) ) (c3): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=64, out_features=8, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=8, out_features=64, bias=False) (3): Sigmoid() ) ) ) (c4): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2)) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=128, out_features=16, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=16, out_features=128, bias=False) (3): Sigmoid() ) ) ) (b1): ASPP( (c1): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c2): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c3): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12)) (1): 
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c4): Sequential( (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18)) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1)) ) (d1): Decoder_Block( (a1): Attention_Block( (g_conv): Sequential( (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False) ) (x_conv): Sequential( (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (gc_conv): Sequential( (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (up): Upsample(scale_factor=2.0, mode='nearest') (r1): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=128, out_features=16, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=16, out_features=128, bias=False) (3): Sigmoid() ) ) ) ) (d2): Decoder_Block( (a1): Attention_Block( (g_conv): Sequential( (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False) ) (x_conv): Sequential( (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (gc_conv): Sequential( (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (up): Upsample(scale_factor=2.0, mode='nearest') (r1): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(160, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1)) (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=64, out_features=8, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=8, out_features=64, bias=False) (3): Sigmoid() ) ) ) ) (d3): Decoder_Block( (a1): Attention_Block( (g_conv): Sequential( (0): BatchNorm2d(16, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False) ) (x_conv): Sequential( (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (gc_conv): Sequential( (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) ) (up): Upsample(scale_factor=2.0, mode='nearest') (r1): ResNet_Block( (c1): Sequential( (0): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (1): ReLU() (2): Conv2d(80, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): ReLU() (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (c2): Sequential( (0): Conv2d(80, 32, kernel_size=(1, 1), stride=(1, 1)) (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (attn): Squeeze_Excitation( (pool): AdaptiveAvgPool2d(output_size=1) (net): Sequential( (0): Linear(in_features=32, out_features=4, bias=False) (1): ReLU(inplace=True) (2): Linear(in_features=4, out_features=32, bias=False) (3): Sigmoid() ) ) ) ) (aspp): ASPP( (c1): Sequential( (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c2): Sequential( (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c3): Sequential( (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c4): Sequential( (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18)) (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (c5): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1)) ) (output): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1)) )
Check the parameters of the model¶
In [ ]:
# check the model parameters
# channel-first input
summary(model, input_size=(3, 256, 256))
---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 16, 256, 256] 448 BatchNorm2d-2 [-1, 16, 256, 256] 32 ReLU-3 [-1, 16, 256, 256] 0 Conv2d-4 [-1, 16, 256, 256] 2,320 Conv2d-5 [-1, 16, 256, 256] 64 BatchNorm2d-6 [-1, 16, 256, 256] 32 AdaptiveAvgPool2d-7 [-1, 16, 1, 1] 0 Linear-8 [-1, 2] 32 ReLU-9 [-1, 2] 0 Linear-10 [-1, 16] 32 Sigmoid-11 [-1, 16] 0 Squeeze_Excitation-12 [-1, 16, 256, 256] 0 Stem_Block-13 [-1, 16, 256, 256] 0 BatchNorm2d-14 [-1, 16, 256, 256] 32 ReLU-15 [-1, 16, 256, 256] 0 Conv2d-16 [-1, 32, 128, 128] 4,640 BatchNorm2d-17 [-1, 32, 128, 128] 64 ReLU-18 [-1, 32, 128, 128] 0 Conv2d-19 [-1, 32, 128, 128] 9,248 Conv2d-20 [-1, 32, 128, 128] 544 BatchNorm2d-21 [-1, 32, 128, 128] 64 AdaptiveAvgPool2d-22 [-1, 32, 1, 1] 0 Linear-23 [-1, 4] 128 ReLU-24 [-1, 4] 0 Linear-25 [-1, 32] 128 Sigmoid-26 [-1, 32] 0 Squeeze_Excitation-27 [-1, 32, 128, 128] 0 ResNet_Block-28 [-1, 32, 128, 128] 0 BatchNorm2d-29 [-1, 32, 128, 128] 64 ReLU-30 [-1, 32, 128, 128] 0 Conv2d-31 [-1, 64, 64, 64] 18,496 BatchNorm2d-32 [-1, 64, 64, 64] 128 ReLU-33 [-1, 64, 64, 64] 0 Conv2d-34 [-1, 64, 64, 64] 36,928 Conv2d-35 [-1, 64, 64, 64] 2,112 BatchNorm2d-36 [-1, 64, 64, 64] 128 AdaptiveAvgPool2d-37 [-1, 64, 1, 1] 0 Linear-38 [-1, 8] 512 ReLU-39 [-1, 8] 0 Linear-40 [-1, 64] 512 Sigmoid-41 [-1, 64] 0 Squeeze_Excitation-42 [-1, 64, 64, 64] 0 ResNet_Block-43 [-1, 64, 64, 64] 0 BatchNorm2d-44 [-1, 64, 64, 64] 128 ReLU-45 [-1, 64, 64, 64] 0 Conv2d-46 [-1, 128, 32, 32] 73,856 BatchNorm2d-47 [-1, 128, 32, 32] 256 ReLU-48 [-1, 128, 32, 32] 0 Conv2d-49 [-1, 128, 32, 32] 147,584 Conv2d-50 [-1, 128, 32, 32] 8,320 BatchNorm2d-51 [-1, 128, 32, 32] 256 AdaptiveAvgPool2d-52 [-1, 128, 1, 1] 0 Linear-53 [-1, 16] 2,048 ReLU-54 [-1, 16] 0 Linear-55 [-1, 128] 2,048 Sigmoid-56 [-1, 128] 0 Squeeze_Excitation-57 [-1, 128, 32, 32] 0 ResNet_Block-58 [-1, 128, 32, 32] 0 Conv2d-59 [-1, 256, 32, 32] 295,168 BatchNorm2d-60 [-1, 256, 32, 32] 512 Conv2d-61 [-1, 256, 32, 32] 295,168 BatchNorm2d-62 [-1, 256, 32, 32] 512 Conv2d-63 [-1, 256, 32, 32] 295,168 BatchNorm2d-64 [-1, 256, 32, 32] 512 Conv2d-65 [-1, 256, 32, 32] 295,168 BatchNorm2d-66 [-1, 256, 32, 32] 512 Conv2d-67 [-1, 256, 32, 32] 65,792 ASPP-68 [-1, 256, 32, 32] 0 BatchNorm2d-69 [-1, 64, 64, 64] 128 ReLU-70 [-1, 64, 64, 64] 0 Conv2d-71 [-1, 256, 64, 64] 147,712 MaxPool2d-72 [-1, 256, 32, 32] 0 BatchNorm2d-73 [-1, 256, 32, 32] 512 ReLU-74 [-1, 256, 32, 32] 0 Conv2d-75 [-1, 256, 32, 32] 590,080 BatchNorm2d-76 [-1, 256, 32, 32] 512 ReLU-77 [-1, 256, 32, 32] 0 Conv2d-78 [-1, 256, 32, 32] 590,080 Attention_Block-79 [-1, 256, 32, 32] 0 Upsample-80 [-1, 256, 64, 64] 0 BatchNorm2d-81 [-1, 320, 64, 64] 640 ReLU-82 [-1, 320, 64, 64] 0 Conv2d-83 [-1, 128, 64, 64] 368,768 BatchNorm2d-84 [-1, 128, 64, 64] 256 ReLU-85 [-1, 128, 64, 64] 0 Conv2d-86 [-1, 128, 64, 64] 147,584 Conv2d-87 [-1, 128, 64, 64] 41,088 BatchNorm2d-88 [-1, 128, 64, 64] 256 AdaptiveAvgPool2d-89 [-1, 128, 1, 1] 0 Linear-90 [-1, 16] 2,048 ReLU-91 [-1, 16] 0 Linear-92 [-1, 128] 2,048 Sigmoid-93 [-1, 128] 0 Squeeze_Excitation-94 [-1, 128, 64, 64] 0 ResNet_Block-95 [-1, 128, 64, 64] 0 Decoder_Block-96 [-1, 128, 64, 64] 0 BatchNorm2d-97 [-1, 32, 128, 128] 64 ReLU-98 [-1, 32, 128, 128] 0 Conv2d-99 [-1, 128, 128, 128] 36,992 MaxPool2d-100 [-1, 128, 64, 64] 0 BatchNorm2d-101 [-1, 128, 64, 64] 256 ReLU-102 [-1, 128, 64, 64] 0 Conv2d-103 [-1, 128, 64, 64] 147,584 BatchNorm2d-104 [-1, 128, 64, 64] 256 ReLU-105 [-1, 
128, 64, 64] 0 Conv2d-106 [-1, 128, 64, 64] 147,584 Attention_Block-107 [-1, 128, 64, 64] 0 Upsample-108 [-1, 128, 128, 128] 0 BatchNorm2d-109 [-1, 160, 128, 128] 320 ReLU-110 [-1, 160, 128, 128] 0 Conv2d-111 [-1, 64, 128, 128] 92,224 BatchNorm2d-112 [-1, 64, 128, 128] 128 ReLU-113 [-1, 64, 128, 128] 0 Conv2d-114 [-1, 64, 128, 128] 36,928 Conv2d-115 [-1, 64, 128, 128] 10,304 BatchNorm2d-116 [-1, 64, 128, 128] 128 AdaptiveAvgPool2d-117 [-1, 64, 1, 1] 0 Linear-118 [-1, 8] 512 ReLU-119 [-1, 8] 0 Linear-120 [-1, 64] 512 Sigmoid-121 [-1, 64] 0 Squeeze_Excitation-122 [-1, 64, 128, 128] 0 ResNet_Block-123 [-1, 64, 128, 128] 0 Decoder_Block-124 [-1, 64, 128, 128] 0 BatchNorm2d-125 [-1, 16, 256, 256] 32 ReLU-126 [-1, 16, 256, 256] 0 Conv2d-127 [-1, 64, 256, 256] 9,280 MaxPool2d-128 [-1, 64, 128, 128] 0 BatchNorm2d-129 [-1, 64, 128, 128] 128 ReLU-130 [-1, 64, 128, 128] 0 Conv2d-131 [-1, 64, 128, 128] 36,928 BatchNorm2d-132 [-1, 64, 128, 128] 128 ReLU-133 [-1, 64, 128, 128] 0 Conv2d-134 [-1, 64, 128, 128] 36,928 Attention_Block-135 [-1, 64, 128, 128] 0 Upsample-136 [-1, 64, 256, 256] 0 BatchNorm2d-137 [-1, 80, 256, 256] 160 ReLU-138 [-1, 80, 256, 256] 0 Conv2d-139 [-1, 32, 256, 256] 23,072 BatchNorm2d-140 [-1, 32, 256, 256] 64 ReLU-141 [-1, 32, 256, 256] 0 Conv2d-142 [-1, 32, 256, 256] 9,248 Conv2d-143 [-1, 32, 256, 256] 2,592 BatchNorm2d-144 [-1, 32, 256, 256] 64 AdaptiveAvgPool2d-145 [-1, 32, 1, 1] 0 Linear-146 [-1, 4] 128 ReLU-147 [-1, 4] 0 Linear-148 [-1, 32] 128 Sigmoid-149 [-1, 32] 0 Squeeze_Excitation-150 [-1, 32, 256, 256] 0 ResNet_Block-151 [-1, 32, 256, 256] 0 Decoder_Block-152 [-1, 32, 256, 256] 0 Conv2d-153 [-1, 16, 256, 256] 4,624 BatchNorm2d-154 [-1, 16, 256, 256] 32 Conv2d-155 [-1, 16, 256, 256] 4,624 BatchNorm2d-156 [-1, 16, 256, 256] 32 Conv2d-157 [-1, 16, 256, 256] 4,624 BatchNorm2d-158 [-1, 16, 256, 256] 32 Conv2d-159 [-1, 16, 256, 256] 4,624 BatchNorm2d-160 [-1, 16, 256, 256] 32 Conv2d-161 [-1, 16, 256, 256] 272 ASPP-162 [-1, 16, 256, 256] 0 Conv2d-163 [-1, 3, 256, 256] 51 ================================================================ Total params: 4,063,027 Trainable params: 4,063,027 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.75 Forward/backward pass size (MB): 893.51 Params size (MB): 15.50 Estimated Total Size (MB): 909.76 ----------------------------------------------------------------
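As a quick cross-check, the trainable-parameter count can also be computed directly from the model; it should match the 4,063,027 reported by torchsummary above.
In [ ]:
# Cross-check of the parameter count reported by torchsummary.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")   # expected: 4,063,027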
Check the graph of the model¶
In [ ]:
# draw the computational graph of the model
model_graph1 = draw_graph(model, input_size=[(1, 3, 256, 256)], expand_nested=True)
model_graph1.visual_graph.render(format='png')
model_graph1.visual_graph
Out[ ]:
Data generator to query the input images and masks¶
In [ ]:
# Use the data generator to load the dataset
class DataGenerator(Dataset):
def __init__(self, image_list, masks_folder, num_classes=3):
self.files = image_list
self.masks_folder = masks_folder
self.num_classes = num_classes
# NUMBER OF FILES IN THE DATASET
def __len__(self):
return len(self.files)
# GETTING SINGLE PAIR OF DATA
def __getitem__(self, idx):
file_name = self.files[idx].split('/')[-1]
file_name = file_name[:-4]
mask_name = './{}/'.format(self.masks_folder) + file_name + '.png'
img = cv2.imread(self.files[idx], cv2.IMREAD_UNCHANGED)
if len(img.shape) == 2:
img = cv2.merge((img, img, img))
# Read mask as 3-channel for multi-class
mask = cv2.imread(mask_name, cv2.IMREAD_COLOR) # Read as BGR
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB) # Convert to RGB
# Convert RGB mask to class indices (0, 1, 2)
# Assuming: Black (0,0,0) = background(0), Red-ish = class-1(1), Green-ish = class-2(2)
        mask_class = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)  # np.int64 instead of the deprecated np.long
# Define color thresholds for each class
# Background: mostly black/dark
background_mask = (mask.sum(axis=2) < 50)
# Class 1: red-ish (R > G and R > B)
class1_mask = (mask[:,:,0] > mask[:,:,1]) & (mask[:,:,0] > mask[:,:,2]) & (mask[:,:,0] > 50)
# Class 2: green-ish (G > R and G > B)
class2_mask = (mask[:,:,1] > mask[:,:,0]) & (mask[:,:,1] > mask[:,:,2]) & (mask[:,:,1] > 50)
mask_class[background_mask] = 0
mask_class[class1_mask] = 1
mask_class[class2_mask] = 2
# Transpose image
img_transpose = np.transpose(img, (2, 0, 1))
# Normalize image to [0, 1]
img_normalized = img_transpose / 255.0
return torch.FloatTensor(img_normalized), torch.LongTensor(mask_class), str(file_name)
def load_data(image_list, masks_folder, batch_size=2, num_workers=10, shuffle=True, num_classes=3):
dataset = DataGenerator(image_list, masks_folder=masks_folder, num_classes=num_classes)
data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
return data_loader
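A quick optional sanity check of the generator, assuming the split folders created earlier: fetch one small batch and print the tensor shapes. Images come out as [B, 3, 256, 256] floats in [0, 1] and masks as [B, 256, 256] class indices.
In [ ]:
# Optional sanity check of the data generator (assumes train_images/ and train_masks/ exist).
sample_files = sorted(glob.glob("./train_images/*.png"))[:4]
sample_loader = load_data(sample_files, masks_folder="train_masks", batch_size=2,
                          num_workers=0, shuffle=False)
imgs, masks, names = next(iter(sample_loader))
print(imgs.shape, imgs.dtype)    # torch.Size([2, 3, 256, 256]) torch.float32
print(masks.shape, masks.dtype)  # torch.Size([2, 256, 256]) torch.int64
print(torch.unique(masks))       # a subset of {0, 1, 2}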
Define the Dice coefficient and Intersection-over-Union (IoU) metrics and the categorical cross-entropy loss function¶
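For a single class, writing $P$ for the set of predicted pixels and $G$ for the ground-truth pixels, the two metrics computed below are
$$\mathrm{Dice} = \frac{2\,|P \cap G| + \epsilon}{|P| + |G| + \epsilon}, \qquad \mathrm{IoU} = \frac{|P \cap G| + \epsilon}{|P| + |G| - |P \cap G| + \epsilon},$$
where $\epsilon$ is a small smoothing constant (1e-6 in the code) that avoids division by zero for classes absent from both masks; the per-class scores are then averaged over the classes.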
In [ ]:
# Dice coefficient calculation for multi-class with class indices
def dice_coefficient(pred, target, num_classes, smooth=1e-6):
"""
Calculate dice coefficient for multi-class segmentation
pred: predicted probabilities [B, C, H, W] (after softmax)
target: ground truth class indices [B, H, W]
"""
dice_scores = []
# Convert predictions to class predictions
pred_classes = torch.argmax(pred, dim=1) # [B, H, W]
for class_idx in range(num_classes):
pred_class = (pred_classes == class_idx).float()
target_class = (target == class_idx).float()
intersection = (pred_class * target_class).sum(dim=(1, 2))
union = pred_class.sum(dim=(1, 2)) + target_class.sum(dim=(1, 2))
dice = (2. * intersection + smooth) / (union + smooth)
dice_scores.append(dice.mean())
return torch.stack(dice_scores), torch.stack(dice_scores).mean()
# IoU calculation for multi-class with class indices
def iou_coefficient(pred, target, num_classes, smooth=1e-6):
"""
Calculate IoU for multi-class segmentation
pred: predicted probabilities [B, C, H, W] (after softmax)
target: ground truth class indices [B, H, W]
"""
iou_scores = []
# Convert predictions to class predictions
pred_classes = torch.argmax(pred, dim=1) # [B, H, W]
for class_idx in range(num_classes):
pred_class = (pred_classes == class_idx).float()
target_class = (target == class_idx).float()
intersection = (pred_class * target_class).sum(dim=(1, 2))
union = pred_class.sum(dim=(1, 2)) + target_class.sum(dim=(1, 2)) - intersection
iou = (intersection + smooth) / (union + smooth)
iou_scores.append(iou.mean())
return torch.stack(iou_scores), torch.stack(iou_scores).mean()
# Categorical Cross Entropy Loss with optional class weights
class CategoryCrossEntropyLoss(nn.Module):
def __init__(self, weight=None):
super(CategoryCrossEntropyLoss, self).__init__()
self.criterion = nn.CrossEntropyLoss(weight=weight)
def forward(self, pred, target):
"""
pred: [B, C, H, W] - raw logits from model
target: [B, H, W] - class indices
"""
return self.criterion(pred, target)
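A tiny optional numerical check of the metric functions on a hand-made 2x2 example (nothing beyond the definitions above is assumed):
In [ ]:
# Optional sanity check: a single 2x2 "image" with 3 classes.
# Ground truth: [[0, 1], [2, 1]]; the prediction below gets 3 of the 4 pixels right.
target = torch.tensor([[[0, 1], [2, 1]]])        # [B=1, H=2, W=2]
probs = torch.zeros(1, 3, 2, 2)
probs[0, 0, 0, 0] = 1.0   # pixel (0,0) -> class 0 (correct)
probs[0, 1, 0, 1] = 1.0   # pixel (0,1) -> class 1 (correct)
probs[0, 0, 1, 0] = 1.0   # pixel (1,0) -> class 0 (wrong, should be 2)
probs[0, 1, 1, 1] = 1.0   # pixel (1,1) -> class 1 (correct)
dice_per_class, mean_dice = dice_coefficient(probs, target, num_classes=3)
iou_per_class, mean_iou = iou_coefficient(probs, target, num_classes=3)
print(dice_per_class, mean_dice)   # class 0: 2*1/(2+1) ~ 0.67, class 1: 1.0, class 2: ~0
print(iou_per_class, mean_iou)     # class 0: 0.5, class 1: 1.0, class 2: ~0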
We have three classes with three colors; here we define the color map and overlay helpers to be used later¶
In [ ]:
# Create color map for visualization
def create_color_map():
"""Create color map for 3 classes: background, class-1, class-2"""
color_map = np.array([
[0, 0, 0], # Background - Black
[0, 0, 128], # Class 1 - dark Red
[0, 128, 0], # Class 2 - dark Green
], dtype=np.uint8)
return color_map
def apply_color_map(mask, color_map):
"""Apply color map to single channel mask"""
h, w = mask.shape
colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
for class_idx, color in enumerate(color_map):
colored_mask[mask == class_idx] = color
return colored_mask
def create_overlay(image, mask, alpha=0.6):
"""Create overlay of image and colored mask"""
if len(image.shape) == 3 and image.shape[2] == 3:
# Image is already RGB
overlay = cv2.addWeighted(image.astype(np.uint8), 1-alpha, mask.astype(np.uint8), alpha, 0)
else:
# Convert grayscale to RGB if needed
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
overlay = cv2.addWeighted(image.astype(np.uint8), 1-alpha, mask.astype(np.uint8), alpha, 0)
return overlay
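A small optional demo of these helpers on a dummy 4x4 class-index mask, purely to show the expected shapes and blending behavior:
In [ ]:
# Optional demo: colorize a tiny dummy class-index mask and overlay it on a flat gray image.
dummy_mask = np.array([[0, 0, 1, 1],
                       [0, 0, 1, 1],
                       [2, 2, 0, 0],
                       [2, 2, 0, 0]], dtype=np.uint8)
dummy_img = np.full((4, 4, 3), 200, dtype=np.uint8)          # flat gray image
colored = apply_color_map(dummy_mask, create_color_map())    # (4, 4, 3) colored mask
blended = create_overlay(dummy_img, colored, alpha=0.6)
print(colored[0, 2], blended[0, 2])                          # one class-1 pixel, before and after blending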
Define the function for collecting the image addresses, reading each image once as a sanity check¶
In [ ]:
# sanity check
def get_image_address(image_data_folder, subfolder_name):
total_imgs = glob.glob(image_data_folder+subfolder_name+"/*.png")
error_counter = 0
TOTAL_COUNT = len(total_imgs)
image_address_list = []
for image_name in total_imgs:
try:
img = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
x = img.shape
image_address_list.append(image_name)
except:
print("Image not found : ",image_name)
error_counter+=1
print("Number of Files not found : ", error_counter)
print("Total Number of Files found : ", len(image_address_list))
return image_address_list
In [ ]:
# save checkpoint in pytorch
def save_ckp(checkpoint, checkpoint_path, save_after_epochs):
if checkpoint['epoch'] % save_after_epochs == 0:
torch.save(checkpoint, checkpoint_path)
# load checkpoint in pytorch
def load_ckp(checkpoint_path, model, model_opt):
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])
model_opt.load_state_dict(checkpoint['optimizer'])
return model, model_opt, checkpoint['epoch']
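A brief sketch of how these two helpers fit together (illustrative only; the training loop below builds the checkpoint dictionary with exactly these keys):
In [ ]:
# Sketch (not part of the original run): saving and then resuming from a checkpoint file.
demo_model = ResUnetPlusPlus(input_ch=3, n_class=3).to(device)
demo_opt = optim.Adam(demo_model.parameters(), lr=1e-3)
ckpt = {'epoch': 1, 'state_dict': demo_model.state_dict(), 'optimizer': demo_opt.state_dict()}
save_ckp(ckpt, "checkpoint_demo.pt", save_after_epochs=1)   # 1 % 1 == 0, so the file is written
demo_model, demo_opt, resumed_epoch = load_ckp("checkpoint_demo.pt", demo_model, demo_opt)
print("resumed at epoch", resumed_epoch)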
Here, we define the function for training the model. Some of the core components of this section are:¶
- model.train() -> sets the model to training mode (weights will be updated)
- optimizer.zero_grad() -> clears the gradients from the previous pass of the computational graph
- Loss computation using the predicted masks and the ground-truth masks
- loss.backward() -> performs the backward pass
- optimizer.step() -> updates the weights
The train and val functions save the training and validation metrics (i.e., the average Dice and IoU over the training and validation datasets) as the training progresses.¶
In [ ]:
def train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes=3):
print("\n\n---------------------------------------------------------------------------------------------------------------\n")
progress_bar = tqdm(enumerate(train_loader))
total_loss = 0.0
total_dice = 0.0
total_iou = 0.0
model.train()
# Initialize loss function with class weights (optional - adjust based on your class imbalance)
class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device) # Less weight for background
criterion = CategoryCrossEntropyLoss(weight=class_weights)
for step, (inp__, gt__, file_name) in progress_bar:
#TRANSFERRING DATA TO DEVICE
inp__ = inp__.to(device)
gt__ = gt__.to(device)
# clear the gradient
optimizer.zero_grad()
        # GETTING THE PREDICTED MASKS
        # (ResUnetPlusPlus.forward already applies softmax for n_class > 1, so this is
        #  a probability map rather than raw logits)
        pred_logits = model.forward(inp__)
        # Softmax again for the metric computation (the argmax used by the metrics is unchanged)
        pred_probs = torch.softmax(pred_logits, dim=1)
        # CALCULATING LOSS (nn.CrossEntropyLoss expects logits; feeding it the softmax
        # output still trains, but is non-standard)
        loss = criterion(pred_logits, gt__)
#LOSS TAKEN INTO CONSIDERATION
total_loss += loss.item()
# Calculate metrics using probabilities
with torch.no_grad():
dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
total_dice += mean_dice.item()
total_iou += mean_iou.item()
#BACKPROPAGATING THE LOSS
loss.backward()
optimizer.step()
#DISPLAYING THE LOSS
progress_bar.set_description("Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
epoch, loss.item(), mean_dice.item(), mean_iou.item()))
avg_loss = total_loss / len(train_loader)
avg_dice = total_dice / len(train_loader)
avg_iou = total_iou / len(train_loader)
with open("{}/train_logs.txt".format(hist_folder_name), "a") as text_file:
text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))
print(Fore.GREEN+"Training Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)
return model, optimizer
Here, we define the validation epoch. We set the model to model.eval() mode (so layers such as BatchNorm and Dropout behave deterministically), and the epoch is run under torch.no_grad() in the master loop, so no weights are updated; the loss and metrics are still computed, but only for logging¶
In [ ]:
def val_epoch(val_loader, model, optimizer, epoch, hist_folder_name, num_classes=3):
model.eval()
progress_bar = tqdm(enumerate(val_loader))
total_loss = 0.0
total_dice = 0.0
total_iou = 0.0
class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device)
criterion = CategoryCrossEntropyLoss(weight=class_weights)
for step, (inp__, gt__, file_name) in progress_bar:
inp__ = inp__.to(device)
gt__ = gt__.to(device)
#PREDICTED IMAGE
pred_logits = model.forward(inp__)
pred_probs = torch.softmax(pred_logits, dim=1)
#CALCULATING LOSSES
loss = criterion(pred_logits, gt__)
total_loss += loss.item()
# Calculate metrics
dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
total_dice += mean_dice.item()
total_iou += mean_iou.item()
progress_bar.set_description("Val Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
epoch, loss.item(), mean_dice.item(), mean_iou.item()))
avg_loss = total_loss / len(val_loader)
avg_dice = total_dice / len(val_loader)
avg_iou = total_iou / len(val_loader)
with open("{}/val_logs.txt".format(hist_folder_name), "a") as text_file:
text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))
print(Fore.BLUE+"Validation Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)
Function for the test epoch -> for sanity check¶
- Here we test the model on the test dataset after training; we dump some of the images and inspect how the model performs by overlaying the predicted masks on the images
In [ ]:
def test_epoch(test_loader, model, optimizer, epoch, inf_folder_name, hist_folder_name, num_classes=3):
model.eval()
progress_bar = tqdm(enumerate(test_loader))
total_loss = 0.0
total_dice = 0.0
total_iou = 0.0
#SETTING THE NUMBER OF IMAGES TO CHECK AFTER EACH ITERATION
no_img_to_write = 20
class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device)
criterion = CategoryCrossEntropyLoss(weight=class_weights)
# Get color map for visualization
color_map = create_color_map()
for step, (inp__, gt__, file_name) in progress_bar:
inp__ = inp__.to(device)
gt__ = gt__.to(device)
#PREDICTED IMAGE
pred_logits = model.forward(inp__)
pred_probs = torch.softmax(pred_logits, dim=1)
#CALCULATING LOSSES
loss = criterion(pred_logits, gt__)
total_loss += loss.item()
# Calculate metrics
dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
total_dice += mean_dice.item()
total_iou += mean_iou.item()
progress_bar.set_description("Test Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
epoch, loss.item(), mean_dice.item(), mean_iou.item()))
#WRITING THE IMAGES INTO THE SPECIFIED DIRECTORY
if(step < no_img_to_write):
pred_classes = torch.argmax(pred_probs, dim=1) # Get class predictions
p_img = pred_classes.cpu().numpy() # [B, H, W]
gt_img = gt__.cpu().numpy() # [B, H, W]
inp_img = inp__.cpu().numpy() # [B, C, H, W]
#FOLDER PATH TO WRITE THE INFERENCES
inference_folder = "{}".format(inf_folder_name)
if not os.path.isdir(inference_folder):
os.mkdir(inference_folder)
print("\n Saving inferences at epoch === ",epoch)
# Save inference images for multi-class with overlays
for batch_idx, (p_image_loop, gt_img_loop, inp_img_loop) in enumerate(zip(p_img, gt_img, inp_img)):
# Prepare input image for overlay
inp_img_display = np.transpose(inp_img_loop, (1, 2, 0)) # [H, W, C]
inp_img_display = (inp_img_display * 255.0).astype(np.uint8)
# 1. Save raw input image
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_input.png"),
cv2.cvtColor(inp_img_display, cv2.COLOR_RGB2BGR))
# 2. Create and save colored prediction mask
pred_colored = apply_color_map(p_image_loop, color_map)
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_pred_mask.png"),
cv2.cvtColor(pred_colored, cv2.COLOR_RGB2BGR))
# 3. Create and save colored ground truth mask
gt_colored = apply_color_map(gt_img_loop, color_map)
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_gt_mask.png"),
cv2.cvtColor(gt_colored, cv2.COLOR_RGB2BGR))
# 4. Create and save prediction overlay
pred_overlay = create_overlay(inp_img_display, pred_colored, alpha=0.4)
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_pred_overlay.png"),
cv2.cvtColor(pred_overlay, cv2.COLOR_RGB2BGR))
# 5. Create and save ground truth overlay
gt_overlay = create_overlay(inp_img_display, gt_colored, alpha=0.4)
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_gt_overlay.png"),
cv2.cvtColor(gt_overlay, cv2.COLOR_RGB2BGR))
# 6. Create side-by-side comparison
comparison = np.hstack([
inp_img_display,
gt_colored,
pred_colored,
gt_overlay,
pred_overlay
])
cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_comparison.png"),
cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
avg_loss = total_loss / len(test_loader)
avg_dice = total_dice / len(test_loader)
avg_iou = total_iou / len(test_loader)
with open("{}/test_logs.txt".format(hist_folder_name), "a") as text_file:
text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))
print(Fore.RED+"Test Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)
print("---------------------------------------------------------------------------------------------------------------")
Master epoch to control the train, validation and test epochs. Here, we can save the checkpoint after each epoch or after a specified number of epochs, and either resume training or start from scratch. This section also creates all the folders for dumping the inferences and the history/metrics recorded during training¶
In [ ]:
def train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume, model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes=3):
checkpoint_folder_name = "checkpoint"
Path(checkpoint_folder_name).mkdir(parents=True, exist_ok=True)
Path(inf_folder_name).mkdir(parents=True, exist_ok=True)
Path(hist_folder_name).mkdir(parents=True, exist_ok=True)
epoch = 0
#PATH TO SAVE THE CHECKPOINT
checkpoint_path = "checkpoint/{}_{}.pt".format(model_name,epoch)
#IF TRAINING IS TO RESUMED FROM A CERTAIN CHECKPOINT
if resume:
model, optimizer, epoch = load_ckp(
checkpoint_path, model, optimizer)
while epoch <= n_epoch:
        # the file name uses the epoch counter before it is incremented below, so the
        # numbering is slightly offset from the 'epoch' field stored in the checkpoint
        checkpoint_path = "checkpoint/{}_{}.pt".format(model_name, epoch)
epoch += 1
model, optimizer = train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes)
#CHECKPOINT CREATION
checkpoint = {'epoch': epoch+1, 'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()}
#CHECKPOINT SAVING
save_ckp(checkpoint, checkpoint_path, save_after_epochs)
print(Fore.BLACK+"Checkpoint Saved"+Style.RESET_ALL)
with torch.no_grad():
val_epoch(val_loader, model, optimizer, epoch, hist_folder_name, num_classes)
print("************************ Final Test Epoch *****************************")
with torch.no_grad():
test_epoch(test_loader, model, optimizer, epoch, inf_folder_name, hist_folder_name, num_classes)
The main function, where all the training is done. We can pass the arguments using argparse, or define a class to store the hyper-parameters and values needed for training. We load the train, validation and test image-mask pairs and the dataloaders, and then call the master epoch to perform training, validation and testing.¶
In [ ]:
def main():
# Set your parameters directly here for Jupyter notebook
class Args:
def __init__(self):
# Training parameters
self.lr = "0.001"
self.batch_size = "4"
self.num_epochs = "2"
self.num_classes = 3
# Dataset folders
self.train_images = "train_images"
self.val_images = "val_images"
self.test_images = "test_images"
self.train_masks = "train_masks"
self.val_masks = "val_masks"
self.test_masks = "test_masks"
# Output folders
self.history_folder_name = "history"
self.inference_folder_name = "inference"
self.chkpt_name = "multiclass_model"
self.save_after_epoch = "1"
args = Args()
train_image_address_list = get_image_address("./", args.train_images)
print("Total Number of Training Images : ", len(train_image_address_list))
val_image_address_list = get_image_address("./", args.val_images)
test_image_address_list = get_image_address("./", args.test_images)
train_masks = args.train_masks
val_masks = args.val_masks
test_masks = args.test_masks
save_after_epochs = int(args.save_after_epoch)
num_classes = args.num_classes
# CREATING THE TRAIN LOADER
train_loader = load_data(
train_image_address_list, masks_folder=train_masks, batch_size=int(args.batch_size),
num_workers=8, shuffle=True, num_classes=num_classes)
val_loader = load_data(
val_image_address_list, masks_folder=val_masks, batch_size=int(args.batch_size),
num_workers=8, shuffle=True, num_classes=num_classes)
# CREATING THE TEST LOADER
test_loader = load_data(
test_image_address_list, masks_folder=test_masks, batch_size=1,
num_workers=8, shuffle=False, num_classes=num_classes)
# CALLING THE MODEL - Make sure your model outputs num_classes channels
model = ResUnetPlusPlus(input_ch=3, n_class=num_classes) # Change this to match your model
# model = nn.DataParallel(model)
model = model.to(device)
# DEFINING THE OPTIMIZER
optimizer = optim.Adam(
[p for p in model.parameters() if p.requires_grad], lr=float(args.lr), weight_decay=5e-4)
n_epoch = int(args.num_epochs)
# INDICATOR VARIABLE TO RESUME TRAINING OR START AFRESH
resume = False
model_name = args.chkpt_name
inf_folder_name = args.inference_folder_name
hist_folder_name = args.history_folder_name
train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume,
model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes)
if __name__ == "__main__":
main()
Number of Files not found : 0 Total Number of Files found : 4723 Total Number of Training Images : 4723 Number of Files not found : 0 Total Number of Files found : 787 Number of Files not found : 0 Total Number of Files found : 2362
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(
---------------------------------------------------------------------------------------------------------------
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn( Epoch: 1 - Loss: 0.9149 - Dice: 0.5427 - IoU: 0.4689: : 54it [00:10, 4.93it/s]
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) /tmp/ipython-input-893361213.py in <cell line: 0>() 73 74 if __name__ == "__main__": ---> 75 main() /tmp/ipython-input-893361213.py in main() 69 hist_folder_name = args.history_folder_name 70 ---> 71 train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume, 72 model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes) 73 /tmp/ipython-input-3502321866.py in train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume, model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes) 18 checkpoint_path = "checkpoint/{}_{}.pt".format(model_name,epoch) 19 epoch += 1 ---> 20 model, optimizer = train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes) 21 22 #CHECKPOINT CREATION /tmp/ipython-input-4030073634.py in train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes) 31 32 #LOSS TAKEN INTO CONSIDERATION ---> 33 total_loss += loss.item() 34 35 # Calculate metrics using probabilities KeyboardInterrupt:
Download the pre-trained model and the training history (the model was trained for 1000 epochs) for loading the best model¶
In [ ]:
! wget https://jimut123.github.io/talks/2025_dlnlp/multiclass_model_998.pt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/train_logs.txt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/val_logs.txt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/test_logs.txt
--2025-09-14 10:56:19-- https://jimut123.github.io/talks/2025_dlnlp/multiclass_model_998.pt Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ... Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 49032778 (47M) [application/octet-stream] Saving to: ‘multiclass_model_998.pt’ multiclass_model_99 100%[===================>] 46.76M 273MB/s in 0.2s 2025-09-14 10:56:22 (273 MB/s) - ‘multiclass_model_998.pt’ saved [49032778/49032778] --2025-09-14 10:56:22-- https://jimut123.github.io/talks/2025_dlnlp/history/train_logs.txt Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ... Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 30957 (30K) [text/plain] Saving to: ‘train_logs.txt’ train_logs.txt 100%[===================>] 30.23K --.-KB/s in 0.001s 2025-09-14 10:56:23 (47.1 MB/s) - ‘train_logs.txt’ saved [30957/30957] --2025-09-14 10:56:23-- https://jimut123.github.io/talks/2025_dlnlp/history/val_logs.txt Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ... Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 30957 (30K) [text/plain] Saving to: ‘val_logs.txt’ val_logs.txt 100%[===================>] 30.23K --.-KB/s in 0.001s 2025-09-14 10:56:23 (52.9 MB/s) - ‘val_logs.txt’ saved [30957/30957] --2025-09-14 10:56:23-- https://jimut123.github.io/talks/2025_dlnlp/history/test_logs.txt Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ... Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 32 [text/plain] Saving to: ‘test_logs.txt’ test_logs.txt 100%[===================>] 32 --.-KB/s in 0s 2025-09-14 10:56:24 (1.90 MB/s) - ‘test_logs.txt’ saved [32/32]
In [ ]:
def load_log_file(filename):
"""Load log file and return arrays for epoch, loss, dice, iou"""
data = np.loadtxt(filename)
epochs = data[:, 0]
loss = data[:, 1]
dice = data[:, 2]
iou = data[:, 3]
return epochs, loss, dice, iou
def plot_metrics(train_epochs, train_loss, train_dice, train_iou,
val_epochs, val_loss, val_dice, val_iou,
save_path='training_history.png'):
"""Plot training and validation metrics"""
# Create figure with subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Training History Metrics', fontsize=16, fontweight='bold', y=0.98)
# Plot Loss
ax1.plot(train_epochs, train_loss, 'b-', linewidth=2, label='Training Loss', alpha=0.8)
ax1.plot(val_epochs, val_loss, 'r-', linewidth=2, label='Validation Loss', alpha=0.8)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Categorical CrossEntropy Loss')
ax1.set_title('Loss Curve')
ax1.legend()
ax1.grid(True, alpha=0.3)
# Plot Dice Coefficient
ax2.plot(train_epochs, train_dice, 'b-', linewidth=2, label='Training Dice', alpha=0.8)
ax2.plot(val_epochs, val_dice, 'r-', linewidth=2, label='Validation Dice', alpha=0.8)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Dice Coefficient')
ax2.set_title('Dice Coefficient')
ax2.legend()
ax2.grid(True, alpha=0.3)
# Plot IoU
ax3.plot(train_epochs, train_iou, 'b-', linewidth=2, label='Training IoU', alpha=0.8)
ax3.plot(val_epochs, val_iou, 'r-', linewidth=2, label='Validation IoU', alpha=0.8)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('IoU')
ax3.set_title('Intersection over Union (IoU)')
ax3.legend()
ax3.grid(True, alpha=0.3)
# Plot all metrics together for comparison
ax4.plot(train_epochs, train_loss, 'b-', linewidth=2, label='Train Loss', alpha=0.8)
ax4.plot(train_epochs, train_dice, 'g-', linewidth=2, label='Train Dice', alpha=0.8)
ax4.plot(train_epochs, train_iou, 'm-', linewidth=2, label='Train IoU', alpha=0.8)
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Metric Value')
ax4.set_title('Training Metrics Comparison')
ax4.legend()
ax4.grid(True, alpha=0.3)
# Adjust layout and save
plt.tight_layout()
plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.show()
plt.close()
print(f"Plot saved as {save_path}")
In [ ]:
# Load training and validation logs
try:
train_epochs, train_loss, train_dice, train_iou = load_log_file('train_logs.txt')
val_epochs, val_loss, val_dice, val_iou = load_log_file('val_logs.txt')
# Plot and save the metrics
plot_metrics(train_epochs, train_loss, train_dice, train_iou,
val_epochs, val_loss, val_dice, val_iou,
save_path='training_history.png')
except FileNotFoundError as e:
print(f"Error: {e}")
print("Please make sure both 'train_logs.txt' and 'val_logs.txt' exist in the current directory.")
except Exception as e:
print(f"Error loading files: {e}")
Plot saved as training_history.png
In this function, we go through the image samples from the test folder and dump them into the test_dumps folder. We also store overlays of the predicted and ground-truth masks on the test images and save them. The function also returns the list of per-image Dice and IoU metrics to be used later¶
In [ ]:
def test_dump_with_checkpoint(test_loader, checkpoint_path, dump_folder="test_dumps", num_classes=3):
"""
Load checkpoint and dump test images with predictions and overlays
"""
# Create dump folder
Path(dump_folder).mkdir(parents=True, exist_ok=True)
# Load model and checkpoint
model = ResUnetPlusPlus(input_ch=3, n_class=num_classes)
model = model.to(device)
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
print(f"Loaded checkpoint from epoch: {checkpoint.get('epoch', 'Unknown')}")
model.eval()
progress_bar = tqdm(enumerate(test_loader))
# Get color map for visualization
color_map = create_color_map()
# Store metrics for each image
image_metrics = []
with torch.no_grad():
for step, (inp__, gt__, file_name) in progress_bar:
inp__ = inp__.to(device)
gt__ = gt__.to(device)
# Get predictions
pred_logits = model(inp__)  # call the model directly (rather than .forward) so any registered hooks run
pred_probs = torch.softmax(pred_logits, dim=1)
pred_classes = torch.argmax(pred_probs, dim=1)
# Calculate metrics for this batch
dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
# Convert to numpy
p_img = pred_classes.cpu().numpy()
gt_img = gt__.cpu().numpy()
inp_img = inp__.cpu().numpy()
progress_bar.set_description(f"Processing image {step+1} - Dice: {mean_dice.item():.4f}")
# Save images for each item in batch
for batch_idx, (p_image_loop, gt_img_loop, inp_img_loop) in enumerate(zip(p_img, gt_img, inp_img)):
# Store metrics for this image (dice/iou above are computed per batch; with batch_size=1, as used below, they are per-image)
current_dice_scores = dice_per_class.cpu().numpy()
current_iou_scores = iou_per_class.cpu().numpy()
image_info = {
'filename': str(file_name[batch_idx]),
'dice_class1': current_dice_scores[1], # Class 1 dice
'dice_class2': current_dice_scores[2], # Class 2 dice
'dice_mean': mean_dice.item(),
'iou_class1': current_iou_scores[1],
'iou_class2': current_iou_scores[2],
'iou_mean': mean_iou.item()
}
image_metrics.append(image_info)
# Prepare input image for overlay
inp_img_display = np.transpose(inp_img_loop, (1, 2, 0))
inp_img_display = (inp_img_display * 255.0).astype(np.uint8)
# Create colored masks
pred_colored = apply_color_map(p_image_loop, color_map)
gt_colored = apply_color_map(gt_img_loop, color_map)
# Create overlays
pred_overlay = create_overlay(inp_img_display, pred_colored, alpha=0.4)
gt_overlay = create_overlay(inp_img_display, gt_colored, alpha=0.4)
# Save all images
base_name = str(file_name[batch_idx])
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_input.png"),
cv2.cvtColor(inp_img_display, cv2.COLOR_RGB2BGR))
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_pred_mask.png"),
cv2.cvtColor(pred_colored, cv2.COLOR_RGB2BGR))
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_gt_mask.png"),
cv2.cvtColor(gt_colored, cv2.COLOR_RGB2BGR))
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_pred_overlay.png"),
cv2.cvtColor(pred_overlay, cv2.COLOR_RGB2BGR))
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_gt_overlay.png"),
cv2.cvtColor(gt_overlay, cv2.COLOR_RGB2BGR))
# Create side-by-side comparison
comparison = np.hstack([inp_img_display, gt_colored, pred_colored, gt_overlay, pred_overlay])
cv2.imwrite(os.path.join(dump_folder, f"{base_name}_comparison.png"),
cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))
print(f"\nCompleted test dump. Images saved in: {dump_folder}")
return image_metrics
This function calculates the mean and standard deviation of the Dice and IoU scores of the predicted masks with respect to the ground-truth masks¶
In [ ]:
def calculate_class_metrics(image_metrics):
"""
Calculate mean ± std for dice coefficients of class 1 and class 2
"""
dice_class1 = [img['dice_class1'] for img in image_metrics]
dice_class2 = [img['dice_class2'] for img in image_metrics]
# Calculate statistics
dice_c1_mean = np.mean(dice_class1)
dice_c1_std = np.std(dice_class1)
dice_c2_mean = np.mean(dice_class2)
dice_c2_std = np.std(dice_class2)
print("\n" + "="*60)
print("DICE COEFFICIENT RESULTS")
print("="*60)
print(f"Class 1 Dice: {dice_c1_mean:.4f} ± {dice_c1_std:.4f}")
print(f"Class 2 Dice: {dice_c2_mean:.4f} ± {dice_c2_std:.4f}")
print("="*60)
# Also calculate overall metrics
overall_dice = [img['dice_mean'] for img in image_metrics]
iou_class1 = [img['iou_class1'] for img in image_metrics]
iou_class2 = [img['iou_class2'] for img in image_metrics]
overall_iou = [img['iou_mean'] for img in image_metrics]
print(f"Overall Dice: {np.mean(overall_dice):.4f} ± {np.std(overall_dice):.4f}\n\n")
print(f"Class 1 IoU: {np.mean(iou_class1):.4f} ± {np.std(iou_class1):.4f}")
print(f"Class 2 IoU: {np.mean(iou_class2):.4f} ± {np.std(iou_class2):.4f}")
print(f"Overall IoU: {np.mean(overall_iou):.4f} ± {np.std(overall_iou):.4f}")
return {
'dice_class1': {'mean': dice_c1_mean, 'std': dice_c1_std},
'dice_class2': {'mean': dice_c2_mean, 'std': dice_c2_std},
'overall_dice': {'mean': np.mean(overall_dice), 'std': np.std(overall_dice)}
}
This function displays the best-performing and worst-performing image-prediction overlays, ranked by the model's prediction quality.¶
In [ ]:
def display_best_worst_images(image_metrics, dump_folder, top_k=5, criterion='dice_mean'):
"""
Display the best and worst performing images based on specified criterion
"""
# Sort images by performance
sorted_metrics = sorted(image_metrics, key=lambda x: x[criterion])
worst_images = sorted_metrics[:top_k] # Lowest scores
best_images = sorted_metrics[-top_k:] # Highest scores
best_images.reverse() # Show best first
print(f"\n" + "="*80)
print(f"TOP {top_k} BEST PERFORMING IMAGES (by {criterion})")
print("="*80)
fig, axes = plt.subplots(top_k, 5, figsize=(20, 4*top_k))
if top_k == 1:
axes = axes.reshape(1, -1)
for idx, img_info in enumerate(best_images):
filename = img_info['filename']
score = img_info[criterion]
print(f"{idx+1}. {filename} - {criterion}: {score:.4f}")
print(f" Class 1 Dice: {img_info['dice_class1']:.4f}, Class 2 Dice: {img_info['dice_class2']:.4f}")
# Load and display images
input_img = cv2.imread(os.path.join(dump_folder, f"{filename}_input.png"))
gt_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_mask.png"))
pred_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_mask.png"))
gt_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_overlay.png"))
pred_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_overlay.png"))
# Convert BGR to RGB for matplotlib
images = [input_img, gt_mask, pred_mask, gt_overlay, pred_overlay]
titles = ['Input', 'GT Mask', 'Pred Mask', 'GT Overlay', 'Pred Overlay']
for j, (img, title) in enumerate(zip(images, titles)):
if img is not None:
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
axes[idx, j].imshow(img_rgb)
axes[idx, j].set_title(f'{title}\n{criterion}: {score:.3f}')
axes[idx, j].axis('off')
plt.suptitle(f'Best {top_k} Images by {criterion}', fontsize=16, y=0.98)
plt.tight_layout()
plt.show()
print(f"\n" + "="*80)
print(f"TOP {top_k} WORST PERFORMING IMAGES (by {criterion})")
print("="*80)
fig, axes = plt.subplots(top_k, 5, figsize=(20, 4*top_k))
if top_k == 1:
axes = axes.reshape(1, -1)
for idx, img_info in enumerate(worst_images):
filename = img_info['filename']
score = img_info[criterion]
print(f"{idx+1}. {filename} - {criterion}: {score:.4f}")
print(f" Class 1 Dice: {img_info['dice_class1']:.4f}, Class 2 Dice: {img_info['dice_class2']:.4f}")
# Load and display images
input_img = cv2.imread(os.path.join(dump_folder, f"{filename}_input.png"))
gt_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_mask.png"))
pred_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_mask.png"))
gt_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_overlay.png"))
pred_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_overlay.png"))
# Convert BGR to RGB for matplotlib
images = [input_img, gt_mask, pred_mask, gt_overlay, pred_overlay]
titles = ['Input', 'GT Mask', 'Pred Mask', 'GT Overlay', 'Pred Overlay']
for j, (img, title) in enumerate(zip(images, titles)):
if img is not None:
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
axes[idx, j].imshow(img_rgb)
axes[idx, j].set_title(f'{title}\n{criterion}: {score:.3f}')
axes[idx, j].axis('off')
plt.suptitle(f'Worst {top_k} Images by {criterion}', fontsize=16, y=0.98)
plt.tight_layout()
plt.show()
Shows the class-wise and overall Dice and IoU of the predicted masks¶
In [ ]:
# Parameters - path of the saved checkpoint and name of the dump folder
checkpoint_path = "multiclass_model_998.pt"
dump_folder = "test_dumps"
num_classes = 3
# Get test image addresses
test_image_address_list = get_image_address("./", "test_images")
print(f"Total Number of Test Images: {len(test_image_address_list)}")
# Create test loader
test_loader = load_data(
test_image_address_list, masks_folder="test_masks", batch_size=1,
num_workers=8, shuffle=False, num_classes=num_classes)
# Run test dump
print("Starting test dump with checkpoint loading...")
image_metrics = test_dump_with_checkpoint(
test_loader, checkpoint_path, dump_folder, num_classes)
# Calculate and display metrics
print("Calculating class-wise metrics...")
metrics_summary = calculate_class_metrics(image_metrics)
Number of Files not found : 0 Total Number of Files found : 2362 Total Number of Test Images: 2362 Starting test dump with checkpoint loading... Loaded checkpoint from epoch: 1000
Processing image 2362 - Dice: 0.7259: : 2362it [02:44, 14.38it/s]
Completed test dump. Images saved in: test_dumps Calculating class-wise metrics... ============================================================ DICE COEFFICIENT RESULTS ============================================================ Class 1 Dice: 0.8544 ± 0.2623 Class 2 Dice: 0.9306 ± 0.2090 ============================================================ Overall Dice: 0.9227 ± 0.1426 Class 1 IoU: 0.8077 ± 0.2738 Class 2 IoU: 0.9136 ± 0.2248 Overall IoU: 0.8964 ± 0.1586
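Optionally, the per-image metrics computed above can be persisted so the best/worst analysis can be repeated without re-running the test dump. The sketch below is not part of the original pipeline and the CSV filename is illustrative.¶
In [ ]:
# Hedged sketch: write the per-image metrics returned by test_dump_with_checkpoint to a CSV
import csv
with open("test_metrics.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(image_metrics[0].keys()))
    writer.writeheader()
    writer.writerows(image_metrics)
print(f"Wrote {len(image_metrics)} rows to test_metrics.csv")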
In [ ]:
# Display best and worst images
print("Displaying best and worst performing images...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_mean')
Displaying best and worst performing images... ================================================================================ TOP 3 BEST PERFORMING IMAGES (by dice_mean) ================================================================================ 1. 0129_5 - dice_mean: 0.9991 Class 1 Dice: 1.0000, Class 2 Dice: 0.9983 2. 0654_1 - dice_mean: 0.9989 Class 1 Dice: 1.0000, Class 2 Dice: 0.9987 3. 0573_0 - dice_mean: 0.9988 Class 1 Dice: 1.0000, Class 2 Dice: 0.9977
================================================================================ TOP 3 WORST PERFORMING IMAGES (by dice_mean) ================================================================================ 1. 0136_5 - dice_mean: 0.2355 Class 1 Dice: 0.0000, Class 2 Dice: 0.0500 2. 0136_3 - dice_mean: 0.2536 Class 1 Dice: 0.0000, Class 2 Dice: 0.0734 3. 0136_2 - dice_mean: 0.2782 Class 1 Dice: 0.0000, Class 2 Dice: 0.1701
In [ ]:
# display based on specific class performance
print("\nDisplaying based on Class 1 Dice performance...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_class1')
Displaying based on Class 1 Dice performance... ================================================================================ TOP 3 BEST PERFORMING IMAGES (by dice_class1) ================================================================================ 1. 0347_1 - dice_class1: 1.0000 Class 1 Dice: 1.0000, Class 2 Dice: 0.9886 2. 0577_5 - dice_class1: 1.0000 Class 1 Dice: 1.0000, Class 2 Dice: 0.9909 3. 0283_4 - dice_class1: 1.0000 Class 1 Dice: 1.0000, Class 2 Dice: 0.9908
================================================================================ TOP 3 WORST PERFORMING IMAGES (by dice_class1) ================================================================================ 1. 0394_4 - dice_class1: 0.0000 Class 1 Dice: 0.0000, Class 2 Dice: 0.7643 2. 0024_5 - dice_class1: 0.0000 Class 1 Dice: 0.0000, Class 2 Dice: 0.8337 3. 0653_4 - dice_class1: 0.0000 Class 1 Dice: 0.0000, Class 2 Dice: 0.9222
In [ ]:
print("\nDisplaying based on Class 2 Dice performance...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_class2')
Displaying based on Class 2 Dice performance... ================================================================================ TOP 3 BEST PERFORMING IMAGES (by dice_class2) ================================================================================ 1. 0202_5 - dice_class2: 1.0000 Class 1 Dice: 0.9678, Class 2 Dice: 1.0000 2. 0173_4 - dice_class2: 1.0000 Class 1 Dice: 0.9186, Class 2 Dice: 1.0000 3. 0775_1 - dice_class2: 1.0000 Class 1 Dice: 0.9272, Class 2 Dice: 1.0000
================================================================================ TOP 3 WORST PERFORMING IMAGES (by dice_class2) ================================================================================ 1. 0424_2 - dice_class2: 0.0000 Class 1 Dice: 0.4869, Class 2 Dice: 0.0000 2. 0061_5 - dice_class2: 0.0000 Class 1 Dice: 1.0000, Class 2 Dice: 0.0000 3. 0221_5 - dice_class2: 0.0000 Class 1 Dice: 0.3215, Class 2 Dice: 0.0000
Here we visualize the learned kernel patterns and the output feature maps for some of the layers¶
In [ ]:
def visualize_conv_kernels(model, layer_name, save_folder="kernel_patterns"):
"""
Visualize and save convolutional kernel patterns from a specific layer
"""
Path(save_folder).mkdir(parents=True, exist_ok=True)
# Get the layer by name
layer = dict(model.named_modules())[layer_name]
if not isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
print(f"Layer {layer_name} is not a convolutional layer")
return
# Get kernel weights [out_channels, in_channels, kernel_h, kernel_w]
kernels = layer.weight.data.cpu().numpy()
print(f"Layer: {layer_name}")
print(f"Kernel shape: {kernels.shape}")
# Number of output channels (filters)
num_filters = kernels.shape[0]
num_input_channels = kernels.shape[1]
# Create subplots - show first 64 filters max
max_filters = min(64, num_filters)
cols = 8
rows = (max_filters + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
if rows == 1:
axes = axes.reshape(1, -1)
for i in range(max_filters):
row = i // cols
col = i % cols
# For multi-channel input, take mean across input channels
if num_input_channels > 1:
kernel_img = np.mean(kernels[i], axis=0)
else:
kernel_img = kernels[i, 0]
# Normalize to [0, 1] for visualization
kernel_img = (kernel_img - kernel_img.min()) / (kernel_img.max() - kernel_img.min() + 1e-8)
axes[row, col].imshow(kernel_img, cmap='viridis', interpolation='nearest')
axes[row, col].set_title(f'Filter {i}', fontsize=8)
axes[row, col].axis('off')
# Hide unused subplots
for i in range(max_filters, rows * cols):
row = i // cols
col = i % cols
axes[row, col].axis('off')
plt.suptitle(f'Kernels from {layer_name}', fontsize=14)
plt.tight_layout()
plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_kernels.png', dpi=300, bbox_inches='tight')
plt.show()
return kernels
def visualize_feature_maps(model, input_tensor, layer_name, save_folder="feature_maps"):
"""
Visualize feature maps from a specific layer given an input
"""
Path(save_folder).mkdir(parents=True, exist_ok=True)
# Hook to capture feature maps
feature_maps = {}
def hook_fn(module, input, output):
feature_maps[layer_name] = output.detach()
# Register hook
layer = dict(model.named_modules())[layer_name]
handle = layer.register_forward_hook(hook_fn)
# Forward pass
model.eval()
with torch.no_grad():
_ = model(input_tensor)
# Remove hook
handle.remove()
# Get feature maps
fmaps = feature_maps[layer_name].cpu().numpy()
print(f"Feature maps shape: {fmaps.shape}") # [batch, channels, height, width]
# Take first sample from batch
fmaps = fmaps[0] # [channels, height, width]
# Show first 64 feature maps
max_maps = min(64, fmaps.shape[0])
cols = 8
rows = (max_maps + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
if rows == 1:
axes = axes.reshape(1, -1)
for i in range(max_maps):
row = i // cols
col = i % cols
fmap = fmaps[i]
axes[row, col].imshow(fmap, cmap='viridis')
axes[row, col].set_title(f'Channel {i}', fontsize=8)
axes[row, col].axis('off')
# Hide unused subplots
for i in range(max_maps, rows * cols):
row = i // cols
col = i % cols
axes[row, col].axis('off')
plt.suptitle(f'Feature Maps from {layer_name}', fontsize=14)
plt.tight_layout()
plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_fmaps.png', dpi=300, bbox_inches='tight')
plt.show()
return fmaps
def analyze_model_patterns(checkpoint_path, sample_input=None):
"""
Main function to analyze kernel patterns and feature maps
"""
# Load model
model = ResUnetPlusPlus(input_ch=3, n_class=3)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()
# Print all layer names to help identify layers of interest
print("Available layers:")
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
print(f" {name}: {type(module).__name__} {module.weight.shape}")
# Create sample input if not provided
if sample_input is None:
sample_input = torch.randn(1, 3, 256, 256).to(device)
# Analyze different layers - modify these based on your model architecture
layers_to_analyze = [
# First layer (typically captures edges/basic patterns)
"c1.c1.0", # first conv in the stem block of this ResUNet++ model
# Middle layers (intermediate features)
"c3.c1.2", # conv inside an encoder ResNet block
# Final layers (high-level features)
"d1.r1.c1.2", # conv inside the first decoder ResNet block
"output", # final 1x1 output conv
]
# You may need to adjust layer names based on your actual model structure
# Uncomment and modify the layer names that exist in your model:
# Example for common layer names:
# layers_to_analyze = [
# "conv1",
# "layer1.0.conv1",
# "layer2.0.conv1",
# "layer3.0.conv1",
# "layer4.0.conv1",
# "classifier" # final layer
# ]
print("\nAnalyzing kernel patterns...")
for layer_name in layers_to_analyze:
try:
print(f"\n--- Analyzing {layer_name} ---")
# Visualize kernels
kernels = visualize_conv_kernels(model, layer_name)
# Visualize feature maps
feature_maps = visualize_feature_maps(model, sample_input, layer_name)
except KeyError:
print(f"Layer '{layer_name}' not found in model. Skipping...")
continue
except Exception as e:
print(f"Error analyzing layer '{layer_name}': {e}")
continue
def visualize_conv_kernels_rgb(model, layer_name, save_folder="kernel_patterns"):
"""
Visualize convolutional kernels as RGB-style images
"""
Path(save_folder).mkdir(parents=True, exist_ok=True)
layer = dict(model.named_modules())[layer_name]
if not isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
print(f"Layer {layer_name} is not a convolutional layer")
return
kernels = layer.weight.data.cpu().numpy()
print(f"Layer: {layer_name}, Kernel shape: {kernels.shape}")
num_filters = kernels.shape[0]
num_input_channels = kernels.shape[1]
# Show up to 32 filters for better visualization
max_filters = min(32, num_filters)
cols = 8
rows = (max_filters + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
if rows == 1:
axes = axes.reshape(1, -1)
for i in range(max_filters):
row = i // cols
col = i % cols
kernel = kernels[i]
if num_input_channels >= 3:
# Take first 3 channels as RGB
kernel_rgb = kernel[:3] # Shape: [3, H, W]
kernel_rgb = np.transpose(kernel_rgb, (1, 2, 0)) # [H, W, 3]
# Normalize each channel independently
for c in range(3):
channel = kernel_rgb[:, :, c]
kernel_rgb[:, :, c] = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8)
elif num_input_channels == 1:
# Convert single channel to RGB (grayscale)
kernel_gray = kernel[0]
kernel_gray = (kernel_gray - kernel_gray.min()) / (kernel_gray.max() - kernel_gray.min() + 1e-8)
kernel_rgb = np.stack([kernel_gray] * 3, axis=-1)
else:
# For other cases, use mean across input channels
kernel_mean = np.mean(kernel, axis=0)
kernel_mean = (kernel_mean - kernel_mean.min()) / (kernel_mean.max() - kernel_mean.min() + 1e-8)
kernel_rgb = np.stack([kernel_mean] * 3, axis=-1)
axes[row, col].imshow(kernel_rgb)
axes[row, col].set_title(f'Filter {i}', fontsize=8)
axes[row, col].axis('off')
# Hide unused subplots
for i in range(max_filters, rows * cols):
row = i // cols
col = i % cols
axes[row, col].axis('off')
plt.suptitle(f'RGB-style Kernels from {layer_name}', fontsize=14)
plt.tight_layout()
plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_rgb_kernels.png', dpi=300, bbox_inches='tight')
plt.show()
return kernels
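# --- Grad-CAM background (added note) ----------------------------------------
# Grad-CAM weights each feature-map channel A_k of a chosen layer by the
# global-average-pooled gradient of the target-class score w.r.t. that channel,
# sums the weighted channels, and applies a ReLU. generate_grad_cam below does
# this on the hooked tensors; the helper here is only a hedged, standalone NumPy
# sketch of that weighting step (acts/grads are hypothetical [C, H, W] arrays)
# and is not called anywhere in the notebook.
def gradcam_weighting_sketch(acts, grads):
    """Toy Grad-CAM weighting for [C, H, W] NumPy arrays acts (activations) and grads (gradients)."""
    weights = grads.mean(axis=(1, 2))                                   # one weight per channel (GAP of gradients)
    cam = np.maximum((weights[:, None, None] * acts).sum(axis=0), 0.0)  # weighted sum, then ReLU
    return cam / cam.max() if cam.max() > 0 else cam                    # normalize to [0, 1]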
def generate_grad_cam(model, input_tensor, target_layer, target_class=None):
"""
Generate Grad-CAM heatmap for understanding what the network focuses on
"""
model.eval()
input_tensor.requires_grad_(True) # Enable gradients on input
# Store gradients and feature maps
gradients = []
activations = []
def backward_hook(module, grad_input, grad_output):
if grad_output[0] is not None:
gradients.append(grad_output[0].detach())
def forward_hook(module, input, output):
activations.append(output.detach())
try:
# Register hooks
target_layer_module = dict(model.named_modules())[target_layer]
h1 = target_layer_module.register_backward_hook(backward_hook) # NOTE: deprecated in recent PyTorch (hence the FutureWarning in the output below); register_full_backward_hook is the recommended replacement
h2 = target_layer_module.register_forward_hook(forward_hook)
# Forward pass
output = model(input_tensor)
# If no target class specified, use the class with highest probability
if target_class is None:
target_class = output.argmax(dim=1).item()
# Backward pass
model.zero_grad()
# Get the score for target class (handle different output shapes)
if len(output.shape) == 4: # [B, C, H, W] - segmentation output
class_score = output[0, target_class].mean()
else: # [B, C] - classification output
class_score = output[0, target_class]
class_score.backward(retain_graph=True)
# Remove hooks
h1.remove()
h2.remove()
# Check if we got gradients and activations
if not gradients or not activations:
print(f"No gradients or activations captured for layer {target_layer}")
return np.zeros((64, 64)), target_class
# Generate Grad-CAM
gradients_tensor = gradients[0][0] # [C, H, W]
activations_tensor = activations[0][0] # [C, H, W]
# Ensure tensors are on the same device
device = activations_tensor.device
gradients_tensor = gradients_tensor.to(device)
# Calculate weights (global average pooling of gradients)
weights = torch.mean(gradients_tensor, dim=(1, 2)) # [C]
# Generate heatmap on the same device
cam = torch.zeros(activations_tensor.shape[1:], dtype=torch.float32, device=device)
for i, w in enumerate(weights):
cam += w * activations_tensor[i, :, :]
cam = torch.clamp(cam, min=0) # ReLU
cam = cam / cam.max() if cam.max() > 0 else cam # Normalize
return cam.detach().cpu().numpy(), target_class
except Exception as e:
print(f"Error in Grad-CAM for layer {target_layer}: {e}")
return np.zeros((64, 64)), 0
def visualize_layer_attention(model, input_image, checkpoint_path, save_folder="attention_maps"):
"""
Visualize what different layers focus on for a given input image
"""
Path(save_folder).mkdir(parents=True, exist_ok=True)
# Load model
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()
# Prepare input
if isinstance(input_image, str):
# Load image from path
img = cv2.imread(input_image)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256))
img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
else:
img_tensor = input_image
img = (input_image.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
img_tensor = img_tensor.to(device)
# Updated key layers based on your actual ResUNet++ architecture
key_layers = [
'c1.c1.0', # First conv in stem block
'c2.c1.2', # First conv in ResNet block c2
'c3.c1.2', # First conv in ResNet block c3
'c4.c1.2', # First conv in ResNet block c4
'b1.c1.0', # First branch in ASPP
'd1.r1.c1.2', # First decoder ResNet block
'd2.r1.c1.2', # Second decoder ResNet block
'output' # Final output layer
]
# Create figure with appropriate number of subplots
num_plots = len(key_layers) + 1 # +1 for original image
cols = 3
rows = (num_plots + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows))
if rows == 1:
axes = axes.reshape(1, -1)
axes = axes.flatten()
# Show original image
axes[0].imshow(img)
axes[0].set_title('Original Image')
axes[0].axis('off')
print("Generating attention maps...")
valid_idx = 1
for layer_name in key_layers:
try:
print(f"Processing layer: {layer_name}")
# Check if layer exists
if layer_name not in dict(model.named_modules()):
print(f"Layer {layer_name} not found, skipping...")
continue
# Generate Grad-CAM (try class 1 first, then fallback)
heatmap, target_class = generate_grad_cam(model, img_tensor.clone(), layer_name, target_class=1)
if heatmap.max() == 0: # If no activation, try different class
heatmap, target_class = generate_grad_cam(model, img_tensor.clone(), layer_name, target_class=None)
if heatmap.max() == 0: # Still no activation, create placeholder
heatmap = np.random.rand(64, 64) * 0.1
print(f"No significant activation in {layer_name}")
# Resize heatmap to match input image size
if heatmap.shape != (256, 256):
heatmap = cv2.resize(heatmap, (256, 256))
# Create colored heatmap
heatmap_normalized = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
heatmap_colored = plt.cm.jet(heatmap_normalized)[:, :, :3]
# Overlay on original image
overlay = 0.6 * (img/255.0) + 0.4 * heatmap_colored
overlay = np.clip(overlay, 0, 1)
if valid_idx < len(axes):
axes[valid_idx].imshow(overlay)
axes[valid_idx].set_title(f'{layer_name}\n(Class: {target_class})', fontsize=10)
axes[valid_idx].axis('off')
# Save individual attention map
plt.imsave(f'{save_folder}/attention_{layer_name.replace(".", "_")}.png',
overlay, dpi=300)
valid_idx += 1
except Exception as e:
print(f"Error processing layer {layer_name}: {e}")
continue
# Hide unused subplots
for i in range(valid_idx, len(axes)):
axes[i].axis('off')
plt.suptitle('Layer-wise Attention Maps (Grad-CAM)', fontsize=16)
plt.tight_layout()
plt.savefig(f'{save_folder}/combined_attention_analysis.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"Attention maps saved to: {save_folder}")
return True
def visualize_feature_evolution(model, input_image, checkpoint_path, save_folder="feature_evolution"):
"""
Show how features evolve through the network layers
"""
Path(save_folder).mkdir(parents=True, exist_ok=True)
# Load model
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()
# Prepare input
if isinstance(input_image, str):
img = cv2.imread(input_image)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (256, 256))
img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
else:
img_tensor = input_image
img = (input_image.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
img_tensor = img_tensor.to(device)
# Layers to track feature evolution
evolution_layers = [
'c1', # Stem block
'c2', # First encoder
'c3', # Second encoder
'c4', # Third encoder
'b1', # Bridge
'd1', # First decoder
'd2', # Second decoder
'd3', # Third decoder
'output' # Final output
]
# Hook to capture intermediate outputs
intermediate_outputs = {}
hooks = []
def make_hook(name):
def hook(module, input, output):
intermediate_outputs[name] = output.detach()
return hook
# Register hooks
for layer_name in evolution_layers:
try:
layer = dict(model.named_modules())[layer_name]
hook = layer.register_forward_hook(make_hook(layer_name))
hooks.append(hook)
except KeyError:
print(f"Layer {layer_name} not found")
# Forward pass
with torch.no_grad():
output = model(img_tensor)
# Remove hooks
for hook in hooks:
hook.remove()
# Visualize evolution
# One panel for the input image plus one per tracked layer (a fixed 3x3 grid would drop the final 'output' panel)
fig, axes = plt.subplots(2, 5, figsize=(25, 10))
axes = axes.flatten()
# Show original image
axes[0].imshow(img)
axes[0].set_title('Input Image')
axes[0].axis('off')
for idx, layer_name in enumerate(evolution_layers):
if layer_name in intermediate_outputs and idx + 1 < len(axes):
feature_map = intermediate_outputs[layer_name][0] # Remove batch dimension
if len(feature_map.shape) == 3: # [C, H, W]
# Take mean across channels for visualization
feature_viz = torch.mean(feature_map, dim=0).cpu().numpy()
else: # Already 2D
feature_viz = feature_map.cpu().numpy()
# Normalize
feature_viz = (feature_viz - feature_viz.min()) / (feature_viz.max() - feature_viz.min() + 1e-8)
axes[idx + 1].imshow(feature_viz, cmap='viridis')
axes[idx + 1].set_title(f'{layer_name}\nShape: {feature_map.shape}')
axes[idx + 1].axis('off')
plt.suptitle('Feature Evolution Through Network', fontsize=16)
plt.tight_layout()
plt.savefig(f'{save_folder}/feature_evolution.png', dpi=300, bbox_inches='tight')
plt.show()
return intermediate_outputs
def comprehensive_network_analysis(checkpoint_path, sample_image_path):
"""
Comprehensive analysis of what the network sees and understands
"""
print("="*80)
print("COMPREHENSIVE NETWORK ANALYSIS")
print("="*80)
# Load model
model = ResUnetPlusPlus(input_ch=3, n_class=3)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()
# 1. RGB-style kernel visualization
print("\n1. Analyzing RGB-style filter patterns...")
key_conv_layers = ['c1.c1.0', 'c2.c1.2', 'c4.c1.2', 'output']
for layer in key_conv_layers:
try:
visualize_conv_kernels_rgb(model, layer)
except Exception as e:
print(f"Could not analyze layer {layer}: {e}")
# 2. Layer attention analysis
print("\n2. Analyzing layer-wise attention...")
visualize_layer_attention(model, sample_image_path, checkpoint_path)
# 3. Feature evolution
print("\n3. Analyzing feature evolution...")
visualize_feature_evolution(model, sample_image_path, checkpoint_path)
print("\nAnalysis complete! Check the generated folders for results.")
# Usage examples:
comprehensive_network_analysis("multiclass_model_998.pt", "./test_images/0048_3.png")
================================================================================ COMPREHENSIVE NETWORK ANALYSIS ================================================================================ 1. Analyzing RGB-style filter patterns... Layer: c1.c1.0, Kernel shape: (16, 3, 3, 3)
Layer: c2.c1.2, Kernel shape: (32, 16, 3, 3)
Layer: c4.c1.2, Kernel shape: (128, 64, 3, 3)
Layer: output, Kernel shape: (3, 16, 1, 1)
2. Analyzing layer-wise attention... Generating attention maps... Processing layer: c1.c1.0
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1864: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior. self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Processing layer: c2.c1.2 Processing layer: c3.c1.2 Processing layer: c4.c1.2 Processing layer: b1.c1.0 Processing layer: d1.r1.c1.2 Processing layer: d2.r1.c1.2 Processing layer: output
Attention maps saved to: attention_maps 3. Analyzing feature evolution...
Analysis complete! Check the generated folders for results.
In [ ]:
# # Or individual analyses:
# model = ResUnetPlusPlus(input_ch=3, n_class=3)
# checkpoint = torch.load("multiclass_model_998.pt")
# model.load_state_dict(checkpoint['state_dict'])
# visualize_conv_kernels_rgb(model, "c1.c1.0") # First layer RGB kernels
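# # Or, using a real test image instead of random input when inspecting feature maps
# # (hedged sketch; the image path below is the same sample used earlier):
# img = cv2.imread("./test_images/0048_3.png")
# img = cv2.resize(img, (256, 256))
# img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
# visualize_feature_maps(model.to(device), img_tensor.to(device), "c2.c1.2")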
In [ ]: