Convolutional Neural Networks (CNNs) tutorial for CS-772 2025

Author: Jimut Bahan Pal

Multiclass Image Segmentation - Segmenting wild plants¶

The original augmented dataset can be downloaded from here.

A mirror copy of the dataset can also be found here.

Install the necessary libraries¶

In [ ]:
! pip install -q gdown
! pip install -q colorama # for color-based texts
! pip install -q torchviz # for visualizing graphs
! pip install -q torchview # for visualizing graphs
! pip install -q graphviz # for visualizing graphs
! pip install -q torchsummary # for finding the number of parameters of a model
In [ ]:
# delete dataset if already present
! rm -rf dataset.zip

Download the dataset from Google Drive¶

In [ ]:
! gdown --id 1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE -O dataset.zip
/usr/local/lib/python3.12/dist-packages/gdown/__main__.py:140: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
  warnings.warn(
Downloading...
From (original): https://drive.google.com/uc?id=1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE
From (redirected): https://drive.google.com/uc?id=1Ihjzt59qzBmI6JbBgJZpZyTg1wjZvKnE&confirm=t&uuid=c5388783-590f-4505-b990-9ace5b78ac10
To: /content/dataset.zip
100% 707M/707M [00:10<00:00, 66.7MB/s]
In [ ]:
! unzip -qq dataset.zip

Check the number of images present in the whole dataset¶

In [ ]:
! ls weed_augmented/images/* | wc
   7872    7872  259776

Imports¶

In [ ]:
# General imports
import os
import cv2
import glob
import numpy as np
import random
from tqdm import tqdm
from pathlib import Path
from colorama import Fore, Style
from collections import Counter, defaultdict

# Model based imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Viz-based imports
from torchsummary import summary
from torchview import draw_graph
import graphviz
graphviz.set_jupyter_format('png')
import seaborn as sns
import matplotlib.pyplot as plt

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seed PyTorch for reproducible weight initialization
Using device: cuda

This dataset contains three pixel classes; here, we examine the percentage of pixels belonging to each class.¶

In [ ]:
def analyze_3class_masks(mask_folder="./weed_augmented/masks/"):
    """
    Fast analysis for 3-class segmentation masks with specific RGB values.
    """
    # Define the 3 classes.
    # Note: cv2.imread returns pixels in BGR order and no conversion is done below,
    # so these values are in BGR. (0, 0, 128) in BGR corresponds to RGB (128, 0, 0), i.e. dark red.
    classes = {
        'background': np.array([0, 0, 0]),      # Black
        'class_1': np.array([0, 128, 0]),       # Green (same in BGR and RGB)
        'class_2': np.array([0, 0, 128])        # Dark red (BGR)
    }

    # Initialize counters
    counts = {name: 0 for name in classes.keys()}
    total_pixels = 0

    # Get all mask files
    mask_files = glob.glob(f"{mask_folder}*.png")
    print(f"Processing {len(mask_files)} mask files...")

    # Process each mask
    for mask_file in tqdm(mask_files, desc="Analyzing masks"):
        # Read mask (OpenCV returns pixels in BGR order; the class colors above are
        # defined in BGR, so no color conversion is needed here)
        mask_bgr = cv2.imread(mask_file)
        if mask_bgr is None:
            continue

        h, w = mask_bgr.shape[:2]
        total_pixels += h * w

        # Count pixels for each class using vectorized operations
        for class_name, class_color in classes.items():
            # Create boolean mask for this class
            matches = np.all(mask_bgr == class_color, axis=2)
            counts[class_name] += np.sum(matches)

    # Calculate percentages
    print(f"\n{'Class':<12} {'Pixels':<12} {'Percentage'}")
    print("-" * 35)

    for class_name, pixel_count in counts.items():
        percentage = (pixel_count / total_pixels) * 100 if total_pixels > 0 else 0
        print(f"{class_name:<12} {pixel_count:<12,} {percentage:.2f}%")

    print(f"\nTotal pixels: {total_pixels:,}")
    print(f"Files processed: {len(mask_files)}")

    return counts, total_pixels


analyze_3class_masks()
Processing 7872 mask files...
Analyzing masks: 100%|██████████| 7872/7872 [03:49<00:00, 34.25it/s]
Class        Pixels       Percentage
-----------------------------------
background   1,662,369,971 80.56%
class_1      193,860,104  9.39%
class_2      206,718,645  10.02%

Total pixels: 2,063,597,568
Files processed: 7872

Out[ ]:
({'background': np.int64(1662369971),
  'class_1': np.int64(193860104),
  'class_2': np.int64(206718645)},
 2063597568)
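
Since roughly 80% of the pixels are background, the classes are imbalanced. The following minimal sketch (not part of the original pipeline; the notebook later uses hand-picked weights [0.5, 1.0, 1.0]) shows one way to turn the pixel counts above into inverse-frequency class weights for a weighted cross-entropy loss:

# Inverse-frequency class weights from the pixel counts printed above
pixel_counts = np.array([1662369971, 193860104, 206718645], dtype=np.float64)
freqs = pixel_counts / pixel_counts.sum()
weights = 1.0 / freqs
weights = weights / weights.mean()          # normalize so the average weight is 1
print(weights)                              # roughly [0.17, 1.46, 1.37]
class_weights = torch.FloatTensor(weights)  # could be passed to nn.CrossEntropyLoss(weight=...)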

Creating the train-val-test split, with 60% of the dataset for training, 10% for validation, and 30% for testing; the images are also resized to 256x256 before the split folders are created. The image-mask pairs are stored in the respective folders, i.e., train_images, train_masks, val_images, val_masks, test_images, and test_masks.¶

In [ ]:
# Configuration
source_folder = 'weed_augmented'
output_folder = '.'
# (the random seed was already set to 42 in the imports cell above, for reproducibility)


# Create directories
dirs = ['train_images', 'train_masks', 'val_images', 'val_masks', 'test_images', 'test_masks']
for dir_name in dirs:
    os.makedirs(os.path.join(output_folder, dir_name), exist_ok=True)

# Get matching image-mask pairs
images_path = os.path.join(source_folder, 'images')
masks_path = os.path.join(source_folder, 'masks')

image_files = [f for f in os.listdir(images_path) if f.lower().endswith('.jpg')]
pairs = []

for img_file in image_files:
    mask_file = img_file.replace('.jpg', '.png')
    if os.path.exists(os.path.join(masks_path, mask_file)):
        pairs.append((img_file, mask_file))

print(f"Found {len(pairs)} matching image-mask pairs")

# Shuffle and split
random.shuffle(pairs)
n_total = len(pairs)
n_train = int(n_total * 0.6)
n_val = int(n_total * 0.1)

train_pairs = pairs[:n_train]
val_pairs = pairs[n_train:n_train + n_val]
test_pairs = pairs[n_train + n_val:]

print(f"Split: Train={len(train_pairs)}, Val={len(val_pairs)}, Test={len(test_pairs)}")


# Process and resize images
target_size = (256, 256)

for pairs_list, split_name in [(train_pairs, 'train'), (val_pairs, 'val'), (test_pairs, 'test')]:
    for img_file, mask_file_name in tqdm(pairs_list):
        # Load and resize image
        img = cv2.imread(os.path.join(images_path, img_file))
        img_resized = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)

        # Load and resize mask
        mask = cv2.imread(os.path.join(masks_path, mask_file_name))
        mask_resized = cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)

        # Save resized image and mask; the image is written under the mask's .png filename
        # so that image and mask files share the same basename in every split folder
        cv2.imwrite(os.path.join(output_folder, f'{split_name}_images', mask_file_name), img_resized)
        cv2.imwrite(os.path.join(output_folder, f'{split_name}_masks', mask_file_name), mask_resized)
Found 7872 matching image-mask pairs
Split: Train=4723, Val=787, Test=2362
100%|██████████| 4723/4723 [00:49<00:00, 94.84it/s]
100%|██████████| 787/787 [00:08<00:00, 95.92it/s]
100%|██████████| 2362/2362 [00:26<00:00, 89.98it/s]
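
As a quick sanity check (a minimal sketch, not part of the original pipeline), we can verify that every image in each split folder has a same-named mask:

# Every image in each split should have a mask with the same filename
for split in ['train', 'val', 'test']:
    imgs = set(os.listdir(f'{split}_images'))
    masks = set(os.listdir(f'{split}_masks'))
    missing = imgs - masks
    print(f"{split}: {len(imgs)} images, {len(masks)} masks, {len(missing)} images without a mask")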

Creating overlays of the masks with the images, and showing the first 5 samples from each of the train, validation, and test splits¶

In [ ]:
# Analyze mask colors
mask_folder = os.path.join(output_folder, 'train_masks')
mask_files = [f for f in os.listdir(mask_folder) if f.lower().endswith('.png')]
sample_files = random.sample(mask_files, min(20, len(mask_files)))

all_colors = []
for filename in sample_files:
    mask = cv2.imread(os.path.join(mask_folder, filename))
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
    unique_colors = np.unique(mask_rgb.reshape(-1, 3), axis=0)
    all_colors.extend([tuple(color) for color in unique_colors])

color_counts = Counter(all_colors)
class_colors = [color for color, _ in color_counts.most_common(10)]

print("Most common mask colors (RGB):")
for color in class_colors[:5]:
    print(f"  {color}")

# Visualize samples
def create_overlay(image, mask, colors, alpha=0.6):
    overlay = image.copy()
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

    for color in colors:
        if color == (0, 0, 0):
            continue
        color_mask = np.all(mask_rgb == color, axis=2)
        if np.any(color_mask):
            overlay[color_mask] = (alpha * np.array(color) + (1 - alpha) * overlay[color_mask]).astype(np.uint8)
    return overlay

# Show samples for each split
for split in ['train', 'val', 'test']:
    img_folder = os.path.join(output_folder, f'{split}_images')
    mask_folder = os.path.join(output_folder, f'{split}_masks')

    img_files = sorted([f for f in os.listdir(img_folder) if f.lower().endswith('.png')])[:5]

    if not img_files:
        continue

    fig, axes = plt.subplots(3, len(img_files), figsize=(10, 8))
    fig.suptitle(f'{split.capitalize()} Set - First {len(img_files)} Samples', fontsize=16)

    for i, img_file in enumerate(img_files):
        mask_file = img_file

        img = cv2.imread(os.path.join(img_folder, img_file))
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(os.path.join(mask_folder, mask_file))
        mask_rgb = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)

        overlay = create_overlay(img_rgb, mask, class_colors)

        axes[0, i].imshow(img_rgb)
        axes[0, i].set_title(f'Image {i+1}', fontsize=10)
        axes[0, i].axis('off')

        axes[1, i].imshow(mask_rgb)
        axes[1, i].set_title(f'Mask {i+1}', fontsize=10)
        axes[1, i].axis('off')

        axes[2, i].imshow(overlay)
        axes[2, i].set_title(f'Overlay {i+1}', fontsize=10)
        axes[2, i].axis('off')

    plt.tight_layout()
    plt.show()

print(f"Dataset organized in '{output_folder}' folder")
Most common mask colors (RGB):
  (np.uint8(0), np.uint8(0), np.uint8(0))
  (np.uint8(128), np.uint8(0), np.uint8(0))
  (np.uint8(0), np.uint8(128), np.uint8(0))
  (np.uint8(1), np.uint8(0), np.uint8(0))
  (np.uint8(2), np.uint8(0), np.uint8(0))
[Figures: image, mask, and overlay grids for the first 5 samples of the train, validation, and test splits]
Dataset organized in '.' folder

We use the ResUNet++ architecture for image segmentation, which is itself quite lightweight (roughly 4M parameters).¶

ResUNet++ Architecture

In [ ]:
# https://github.com/DebeshJha/ResUNetPlusPlus-with-CRF-and-TTA/blob/master/resunet%2B%2B_pytorch.py
# Note: This uses 3-channel input (in standard channel-first mode)

class Squeeze_Excitation(nn.Module):
    def __init__(self, channel, r=8):
        super().__init__()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.net = nn.Sequential(
            nn.Linear(channel, channel // r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // r, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, inputs):
        b, c, _, _ = inputs.shape
        x = self.pool(inputs).view(b, c)
        x = self.net(x).view(b, c, 1, 1)
        x = inputs * x
        return x

class Stem_Block(nn.Module):
    def __init__(self, in_c, out_c, stride):
        super().__init__()

        self.c1 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        )

        self.c2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
            nn.BatchNorm2d(out_c),
        )

        self.attn = Squeeze_Excitation(out_c)

    def forward(self, inputs):
        x = self.c1(inputs)
        s = self.c2(inputs)
        y = self.attn(x + s)
        return y

class ResNet_Block(nn.Module):
    def __init__(self, in_c, out_c, stride):
        super().__init__()

        self.c1 = nn.Sequential(
            nn.BatchNorm2d(in_c),
            nn.ReLU(),
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(out_c),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        )

        self.c2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, stride=stride, padding=0),
            nn.BatchNorm2d(out_c),
        )

        self.attn = Squeeze_Excitation(out_c)

    def forward(self, inputs):
        x = self.c1(inputs)
        s = self.c2(inputs)
        y = self.attn(x + s)
        return y

class ASPP(nn.Module):
    def __init__(self, in_c, out_c, rate=[1, 6, 12, 18]):
        super().__init__()

        self.c1 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[0], padding=rate[0]),
            nn.BatchNorm2d(out_c)
        )

        self.c2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[1], padding=rate[1]),
            nn.BatchNorm2d(out_c)
        )

        self.c3 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[2], padding=rate[2]),
            nn.BatchNorm2d(out_c)
        )

        self.c4 = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, dilation=rate[3], padding=rate[3]),
            nn.BatchNorm2d(out_c)
        )

        self.c5 = nn.Conv2d(out_c, out_c, kernel_size=1, padding=0)


    def forward(self, inputs):
        x1 = self.c1(inputs)
        x2 = self.c2(inputs)
        x3 = self.c3(inputs)
        x4 = self.c4(inputs)
        x = x1 + x2 + x3 + x4
        y = self.c5(x)
        return y

class Attention_Block(nn.Module):
    def __init__(self, in_c):
        super().__init__()
        out_c = in_c[1]

        self.g_conv = nn.Sequential(
            nn.BatchNorm2d(in_c[0]),
            nn.ReLU(),
            nn.Conv2d(in_c[0], out_c, kernel_size=3, padding=1),
            nn.MaxPool2d((2, 2))
        )

        self.x_conv = nn.Sequential(
            nn.BatchNorm2d(in_c[1]),
            nn.ReLU(),
            nn.Conv2d(in_c[1], out_c, kernel_size=3, padding=1),
        )

        self.gc_conv = nn.Sequential(
            nn.BatchNorm2d(in_c[1]),
            nn.ReLU(),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
        )

    def forward(self, g, x):
        g_pool = self.g_conv(g)
        x_conv = self.x_conv(x)
        gc_sum = g_pool + x_conv
        gc_conv = self.gc_conv(gc_sum)
        y = gc_conv * x
        return y

class Decoder_Block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.a1 = Attention_Block(in_c)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.r1 = ResNet_Block(in_c[0]+in_c[1], out_c, stride=1)

    def forward(self, g, x):
        d = self.a1(g, x)
        d = self.up(d)
        d = torch.cat([d, g], axis=1)
        d = self.r1(d)
        return d

class ResUnetPlusPlus(nn.Module):
    def __init__(self, input_ch, n_class):
        super().__init__()

        self.num_classes = n_class

        self.c1 = Stem_Block(input_ch, 16, stride=1)
        self.c2 = ResNet_Block(16, 32, stride=2)
        self.c3 = ResNet_Block(32, 64, stride=2)
        self.c4 = ResNet_Block(64, 128, stride=2)

        self.b1 = ASPP(128, 256)

        self.d1 = Decoder_Block([64, 256], 128)
        self.d2 = Decoder_Block([32, 128], 64)
        self.d3 = Decoder_Block([16, 64], 32)

        self.aspp = ASPP(32, 16)
        self.output = nn.Conv2d(16, n_class, kernel_size=1, padding=0)

    def forward(self, inputs):
        c1 = self.c1(inputs)
        c2 = self.c2(c1)
        c3 = self.c3(c2)
        c4 = self.c4(c3)

        b1 = self.b1(c4)

        d1 = self.d1(c3, b1)
        d2 = self.d2(c2, d1)
        d3 = self.d3(c1, d2)

        output = self.aspp(d3)
        output = self.output(output)

        # For binary segmentation a sigmoid is applied here; for multi-class we return
        # raw logits, since the nn.CrossEntropyLoss used during training applies
        # log-softmax internally (softmax is applied outside the model for metrics only)
        if self.num_classes == 1:
            output = torch.sigmoid(output)

        return output
In [ ]:
# check the model structure
#CALLING THE MODEL
model = ResUnetPlusPlus(input_ch=3, n_class=3)
# model = nn.DataParallel(model)
model = model.to(device)
print(model)
ResUnetPlusPlus(
  (c1): Stem_Block(
    (c1): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (c2): Sequential(
      (0): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (attn): Squeeze_Excitation(
      (pool): AdaptiveAvgPool2d(output_size=1)
      (net): Sequential(
        (0): Linear(in_features=16, out_features=2, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=2, out_features=16, bias=False)
        (3): Sigmoid()
      )
    )
  )
  (c2): ResNet_Block(
    (c1): Sequential(
      (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
      (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU()
      (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (c2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (attn): Squeeze_Excitation(
      (pool): AdaptiveAvgPool2d(output_size=1)
      (net): Sequential(
        (0): Linear(in_features=32, out_features=4, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=4, out_features=32, bias=False)
        (3): Sigmoid()
      )
    )
  )
  (c3): ResNet_Block(
    (c1): Sequential(
      (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU()
      (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (c2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (attn): Squeeze_Excitation(
      (pool): AdaptiveAvgPool2d(output_size=1)
      (net): Sequential(
        (0): Linear(in_features=64, out_features=8, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=8, out_features=64, bias=False)
        (3): Sigmoid()
      )
    )
  )
  (c4): ResNet_Block(
    (c1): Sequential(
      (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU()
      (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU()
      (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (c2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (attn): Squeeze_Excitation(
      (pool): AdaptiveAvgPool2d(output_size=1)
      (net): Sequential(
        (0): Linear(in_features=128, out_features=16, bias=False)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=16, out_features=128, bias=False)
        (3): Sigmoid()
      )
    )
  )
  (b1): ASPP(
    (c1): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c4): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (d1): Decoder_Block(
    (a1): Attention_Block(
      (g_conv): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (x_conv): Sequential(
        (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (gc_conv): Sequential(
        (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (up): Upsample(scale_factor=2.0, mode='nearest')
    (r1): ResNet_Block(
      (c1): Sequential(
        (0): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(320, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU()
        (5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (c2): Sequential(
        (0): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (attn): Squeeze_Excitation(
        (pool): AdaptiveAvgPool2d(output_size=1)
        (net): Sequential(
          (0): Linear(in_features=128, out_features=16, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=16, out_features=128, bias=False)
          (3): Sigmoid()
        )
      )
    )
  )
  (d2): Decoder_Block(
    (a1): Attention_Block(
      (g_conv): Sequential(
        (0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (x_conv): Sequential(
        (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (gc_conv): Sequential(
        (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (up): Upsample(scale_factor=2.0, mode='nearest')
    (r1): ResNet_Block(
      (c1): Sequential(
        (0): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(160, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU()
        (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (c2): Sequential(
        (0): Conv2d(160, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (attn): Squeeze_Excitation(
        (pool): AdaptiveAvgPool2d(output_size=1)
        (net): Sequential(
          (0): Linear(in_features=64, out_features=8, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=8, out_features=64, bias=False)
          (3): Sigmoid()
        )
      )
    )
  )
  (d3): Decoder_Block(
    (a1): Attention_Block(
      (g_conv): Sequential(
        (0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
      )
      (x_conv): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (gc_conv): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (up): Upsample(scale_factor=2.0, mode='nearest')
    (r1): ResNet_Block(
      (c1): Sequential(
        (0): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(80, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU()
        (5): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (c2): Sequential(
        (0): Conv2d(80, 32, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (attn): Squeeze_Excitation(
        (pool): AdaptiveAvgPool2d(output_size=1)
        (net): Sequential(
          (0): Linear(in_features=32, out_features=4, bias=False)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=4, out_features=32, bias=False)
          (3): Sigmoid()
        )
      )
    )
  )
  (aspp): ASPP(
    (c1): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c2): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c3): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c4): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (c5): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (output): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1))
)
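
As a quick sanity check (a minimal sketch), we can push a dummy batch through the model and confirm that it returns one channel per class at the input resolution:

# Forward a dummy batch to verify the output shape: (batch, n_class, H, W)
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256).to(device)
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 3, 256, 256])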

Check the parameters of the model¶

In [ ]:
# check the model parameters
# channel-first input
summary(model, input_size=(3, 256, 256))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 256, 256]             448
       BatchNorm2d-2         [-1, 16, 256, 256]              32
              ReLU-3         [-1, 16, 256, 256]               0
            Conv2d-4         [-1, 16, 256, 256]           2,320
            Conv2d-5         [-1, 16, 256, 256]              64
       BatchNorm2d-6         [-1, 16, 256, 256]              32
 AdaptiveAvgPool2d-7             [-1, 16, 1, 1]               0
            Linear-8                    [-1, 2]              32
              ReLU-9                    [-1, 2]               0
           Linear-10                   [-1, 16]              32
          Sigmoid-11                   [-1, 16]               0
Squeeze_Excitation-12         [-1, 16, 256, 256]               0
       Stem_Block-13         [-1, 16, 256, 256]               0
      BatchNorm2d-14         [-1, 16, 256, 256]              32
             ReLU-15         [-1, 16, 256, 256]               0
           Conv2d-16         [-1, 32, 128, 128]           4,640
      BatchNorm2d-17         [-1, 32, 128, 128]              64
             ReLU-18         [-1, 32, 128, 128]               0
           Conv2d-19         [-1, 32, 128, 128]           9,248
           Conv2d-20         [-1, 32, 128, 128]             544
      BatchNorm2d-21         [-1, 32, 128, 128]              64
AdaptiveAvgPool2d-22             [-1, 32, 1, 1]               0
           Linear-23                    [-1, 4]             128
             ReLU-24                    [-1, 4]               0
           Linear-25                   [-1, 32]             128
          Sigmoid-26                   [-1, 32]               0
Squeeze_Excitation-27         [-1, 32, 128, 128]               0
     ResNet_Block-28         [-1, 32, 128, 128]               0
      BatchNorm2d-29         [-1, 32, 128, 128]              64
             ReLU-30         [-1, 32, 128, 128]               0
           Conv2d-31           [-1, 64, 64, 64]          18,496
      BatchNorm2d-32           [-1, 64, 64, 64]             128
             ReLU-33           [-1, 64, 64, 64]               0
           Conv2d-34           [-1, 64, 64, 64]          36,928
           Conv2d-35           [-1, 64, 64, 64]           2,112
      BatchNorm2d-36           [-1, 64, 64, 64]             128
AdaptiveAvgPool2d-37             [-1, 64, 1, 1]               0
           Linear-38                    [-1, 8]             512
             ReLU-39                    [-1, 8]               0
           Linear-40                   [-1, 64]             512
          Sigmoid-41                   [-1, 64]               0
Squeeze_Excitation-42           [-1, 64, 64, 64]               0
     ResNet_Block-43           [-1, 64, 64, 64]               0
      BatchNorm2d-44           [-1, 64, 64, 64]             128
             ReLU-45           [-1, 64, 64, 64]               0
           Conv2d-46          [-1, 128, 32, 32]          73,856
      BatchNorm2d-47          [-1, 128, 32, 32]             256
             ReLU-48          [-1, 128, 32, 32]               0
           Conv2d-49          [-1, 128, 32, 32]         147,584
           Conv2d-50          [-1, 128, 32, 32]           8,320
      BatchNorm2d-51          [-1, 128, 32, 32]             256
AdaptiveAvgPool2d-52            [-1, 128, 1, 1]               0
           Linear-53                   [-1, 16]           2,048
             ReLU-54                   [-1, 16]               0
           Linear-55                  [-1, 128]           2,048
          Sigmoid-56                  [-1, 128]               0
Squeeze_Excitation-57          [-1, 128, 32, 32]               0
     ResNet_Block-58          [-1, 128, 32, 32]               0
           Conv2d-59          [-1, 256, 32, 32]         295,168
      BatchNorm2d-60          [-1, 256, 32, 32]             512
           Conv2d-61          [-1, 256, 32, 32]         295,168
      BatchNorm2d-62          [-1, 256, 32, 32]             512
           Conv2d-63          [-1, 256, 32, 32]         295,168
      BatchNorm2d-64          [-1, 256, 32, 32]             512
           Conv2d-65          [-1, 256, 32, 32]         295,168
      BatchNorm2d-66          [-1, 256, 32, 32]             512
           Conv2d-67          [-1, 256, 32, 32]          65,792
             ASPP-68          [-1, 256, 32, 32]               0
      BatchNorm2d-69           [-1, 64, 64, 64]             128
             ReLU-70           [-1, 64, 64, 64]               0
           Conv2d-71          [-1, 256, 64, 64]         147,712
        MaxPool2d-72          [-1, 256, 32, 32]               0
      BatchNorm2d-73          [-1, 256, 32, 32]             512
             ReLU-74          [-1, 256, 32, 32]               0
           Conv2d-75          [-1, 256, 32, 32]         590,080
      BatchNorm2d-76          [-1, 256, 32, 32]             512
             ReLU-77          [-1, 256, 32, 32]               0
           Conv2d-78          [-1, 256, 32, 32]         590,080
  Attention_Block-79          [-1, 256, 32, 32]               0
         Upsample-80          [-1, 256, 64, 64]               0
      BatchNorm2d-81          [-1, 320, 64, 64]             640
             ReLU-82          [-1, 320, 64, 64]               0
           Conv2d-83          [-1, 128, 64, 64]         368,768
      BatchNorm2d-84          [-1, 128, 64, 64]             256
             ReLU-85          [-1, 128, 64, 64]               0
           Conv2d-86          [-1, 128, 64, 64]         147,584
           Conv2d-87          [-1, 128, 64, 64]          41,088
      BatchNorm2d-88          [-1, 128, 64, 64]             256
AdaptiveAvgPool2d-89            [-1, 128, 1, 1]               0
           Linear-90                   [-1, 16]           2,048
             ReLU-91                   [-1, 16]               0
           Linear-92                  [-1, 128]           2,048
          Sigmoid-93                  [-1, 128]               0
Squeeze_Excitation-94          [-1, 128, 64, 64]               0
     ResNet_Block-95          [-1, 128, 64, 64]               0
    Decoder_Block-96          [-1, 128, 64, 64]               0
      BatchNorm2d-97         [-1, 32, 128, 128]              64
             ReLU-98         [-1, 32, 128, 128]               0
           Conv2d-99        [-1, 128, 128, 128]          36,992
       MaxPool2d-100          [-1, 128, 64, 64]               0
     BatchNorm2d-101          [-1, 128, 64, 64]             256
            ReLU-102          [-1, 128, 64, 64]               0
          Conv2d-103          [-1, 128, 64, 64]         147,584
     BatchNorm2d-104          [-1, 128, 64, 64]             256
            ReLU-105          [-1, 128, 64, 64]               0
          Conv2d-106          [-1, 128, 64, 64]         147,584
 Attention_Block-107          [-1, 128, 64, 64]               0
        Upsample-108        [-1, 128, 128, 128]               0
     BatchNorm2d-109        [-1, 160, 128, 128]             320
            ReLU-110        [-1, 160, 128, 128]               0
          Conv2d-111         [-1, 64, 128, 128]          92,224
     BatchNorm2d-112         [-1, 64, 128, 128]             128
            ReLU-113         [-1, 64, 128, 128]               0
          Conv2d-114         [-1, 64, 128, 128]          36,928
          Conv2d-115         [-1, 64, 128, 128]          10,304
     BatchNorm2d-116         [-1, 64, 128, 128]             128
AdaptiveAvgPool2d-117             [-1, 64, 1, 1]               0
          Linear-118                    [-1, 8]             512
            ReLU-119                    [-1, 8]               0
          Linear-120                   [-1, 64]             512
         Sigmoid-121                   [-1, 64]               0
Squeeze_Excitation-122         [-1, 64, 128, 128]               0
    ResNet_Block-123         [-1, 64, 128, 128]               0
   Decoder_Block-124         [-1, 64, 128, 128]               0
     BatchNorm2d-125         [-1, 16, 256, 256]              32
            ReLU-126         [-1, 16, 256, 256]               0
          Conv2d-127         [-1, 64, 256, 256]           9,280
       MaxPool2d-128         [-1, 64, 128, 128]               0
     BatchNorm2d-129         [-1, 64, 128, 128]             128
            ReLU-130         [-1, 64, 128, 128]               0
          Conv2d-131         [-1, 64, 128, 128]          36,928
     BatchNorm2d-132         [-1, 64, 128, 128]             128
            ReLU-133         [-1, 64, 128, 128]               0
          Conv2d-134         [-1, 64, 128, 128]          36,928
 Attention_Block-135         [-1, 64, 128, 128]               0
        Upsample-136         [-1, 64, 256, 256]               0
     BatchNorm2d-137         [-1, 80, 256, 256]             160
            ReLU-138         [-1, 80, 256, 256]               0
          Conv2d-139         [-1, 32, 256, 256]          23,072
     BatchNorm2d-140         [-1, 32, 256, 256]              64
            ReLU-141         [-1, 32, 256, 256]               0
          Conv2d-142         [-1, 32, 256, 256]           9,248
          Conv2d-143         [-1, 32, 256, 256]           2,592
     BatchNorm2d-144         [-1, 32, 256, 256]              64
AdaptiveAvgPool2d-145             [-1, 32, 1, 1]               0
          Linear-146                    [-1, 4]             128
            ReLU-147                    [-1, 4]               0
          Linear-148                   [-1, 32]             128
         Sigmoid-149                   [-1, 32]               0
Squeeze_Excitation-150         [-1, 32, 256, 256]               0
    ResNet_Block-151         [-1, 32, 256, 256]               0
   Decoder_Block-152         [-1, 32, 256, 256]               0
          Conv2d-153         [-1, 16, 256, 256]           4,624
     BatchNorm2d-154         [-1, 16, 256, 256]              32
          Conv2d-155         [-1, 16, 256, 256]           4,624
     BatchNorm2d-156         [-1, 16, 256, 256]              32
          Conv2d-157         [-1, 16, 256, 256]           4,624
     BatchNorm2d-158         [-1, 16, 256, 256]              32
          Conv2d-159         [-1, 16, 256, 256]           4,624
     BatchNorm2d-160         [-1, 16, 256, 256]              32
          Conv2d-161         [-1, 16, 256, 256]             272
            ASPP-162         [-1, 16, 256, 256]               0
          Conv2d-163          [-1, 3, 256, 256]              51
================================================================
Total params: 4,063,027
Trainable params: 4,063,027
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 893.51
Params size (MB): 15.50
Estimated Total Size (MB): 909.76
----------------------------------------------------------------
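
The total above can also be cross-checked directly from the model (a minimal sketch):

# Count trainable parameters directly; should match the total reported by torchsummary
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")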

Check the graph of the model¶

In [ ]:
# draw the computational graph of the model
model_graph1 = draw_graph(model, input_size=[(1, 3, 256, 256)],  expand_nested=True)
model_graph1.visual_graph.render(format='png')
model_graph1.visual_graph
Out[ ]:
[Figure: ResUNet++ computational graph rendered by torchview]

Data generator (PyTorch Dataset) to query the input images and masks¶

In [ ]:
# Use the data generator to load the dataset
class DataGenerator(Dataset):
    def __init__(self, image_list, masks_folder, num_classes=3):
        self.files = image_list
        self.masks_folder = masks_folder
        self.num_classes = num_classes

    # NUMBER OF FILES IN THE DATASET
    def __len__(self):
        return len(self.files)

    # GETTING SINGLE PAIR OF DATA
    def __getitem__(self, idx):
        file_name = self.files[idx].split('/')[-1]
        file_name = file_name[:-4]

        mask_name = './{}/'.format(self.masks_folder) + file_name + '.png'

        img = cv2.imread(self.files[idx], cv2.IMREAD_UNCHANGED)
        if len(img.shape) == 2:
            img = cv2.merge((img, img, img))

        # Read mask as 3-channel for multi-class
        mask = cv2.imread(mask_name, cv2.IMREAD_COLOR)  # Read as BGR
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)     # Convert to RGB

        # Convert RGB mask to class indices (0, 1, 2)
        # Assuming: Black (0,0,0) = background(0), Red-ish = class-1(1), Green-ish = class-2(2)
        mask_class = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)  # np.long was removed in recent NumPy versions

        # Define color thresholds for each class
        # Background: mostly black/dark
        background_mask = (mask.sum(axis=2) < 50)

        # Class 1: red-ish (R > G and R > B)
        class1_mask = (mask[:,:,0] > mask[:,:,1]) & (mask[:,:,0] > mask[:,:,2]) & (mask[:,:,0] > 50)

        # Class 2: green-ish (G > R and G > B)
        class2_mask = (mask[:,:,1] > mask[:,:,0]) & (mask[:,:,1] > mask[:,:,2]) & (mask[:,:,1] > 50)

        mask_class[background_mask] = 0
        mask_class[class1_mask] = 1
        mask_class[class2_mask] = 2

        # Transpose image
        img_transpose = np.transpose(img, (2, 0, 1))

        # Normalize image to [0, 1]
        img_normalized = img_transpose / 255.0

        return torch.FloatTensor(img_normalized), torch.LongTensor(mask_class), str(file_name)


def load_data(image_list, masks_folder, batch_size=2, num_workers=10, shuffle=True, num_classes=3):
    dataset = DataGenerator(image_list, masks_folder=masks_folder, num_classes=num_classes)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
    return data_loader
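
To confirm the generator produces what the model and loss expect, a minimal sketch (assuming the split folders created above) pulls one batch and inspects shapes, dtypes, and the class indices present:

# Images should be float [B, 3, 256, 256] in [0, 1]; masks should be long [B, 256, 256] with values in {0, 1, 2}
sample_files = sorted(glob.glob('./train_images/*.png'))[:4]
sample_loader = load_data(sample_files, masks_folder='train_masks', batch_size=2, num_workers=0)
imgs, masks, names = next(iter(sample_loader))
print(imgs.shape, imgs.dtype, imgs.min().item(), imgs.max().item())
print(masks.shape, masks.dtype, torch.unique(masks))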

Define the Dice coefficient and Intersection-over-Union (IoU) metrics, and the categorical cross-entropy loss function¶

In [ ]:
# Dice coefficient calculation for multi-class with class indices
def dice_coefficient(pred, target, num_classes, smooth=1e-6):
    """
    Calculate dice coefficient for multi-class segmentation
    pred: predicted probabilities [B, C, H, W] (after softmax)
    target: ground truth class indices [B, H, W]
    """
    dice_scores = []

    # Convert predictions to class predictions
    pred_classes = torch.argmax(pred, dim=1)  # [B, H, W]

    for class_idx in range(num_classes):
        pred_class = (pred_classes == class_idx).float()
        target_class = (target == class_idx).float()

        intersection = (pred_class * target_class).sum(dim=(1, 2))
        union = pred_class.sum(dim=(1, 2)) + target_class.sum(dim=(1, 2))

        dice = (2. * intersection + smooth) / (union + smooth)
        dice_scores.append(dice.mean())

    return torch.stack(dice_scores), torch.stack(dice_scores).mean()


# IoU calculation for multi-class with class indices
def iou_coefficient(pred, target, num_classes, smooth=1e-6):
    """
    Calculate IoU for multi-class segmentation
    pred: predicted probabilities [B, C, H, W] (after softmax)
    target: ground truth class indices [B, H, W]
    """
    iou_scores = []

    # Convert predictions to class predictions
    pred_classes = torch.argmax(pred, dim=1)  # [B, H, W]

    for class_idx in range(num_classes):
        pred_class = (pred_classes == class_idx).float()
        target_class = (target == class_idx).float()

        intersection = (pred_class * target_class).sum(dim=(1, 2))
        union = pred_class.sum(dim=(1, 2)) + target_class.sum(dim=(1, 2)) - intersection

        iou = (intersection + smooth) / (union + smooth)
        iou_scores.append(iou.mean())

    return torch.stack(iou_scores), torch.stack(iou_scores).mean()


# Categorical Cross Entropy Loss with optional class weights
class CategoryCrossEntropyLoss(nn.Module):
    def __init__(self, weight=None):
        super(CategoryCrossEntropyLoss, self).__init__()
        self.criterion = nn.CrossEntropyLoss(weight=weight)

    def forward(self, pred, target):
        """
        pred: [B, C, H, W] - raw logits from model
        target: [B, H, W] - class indices
        """
        return self.criterion(pred, target)
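
A tiny worked example (a minimal sketch with synthetic tensors) shows how the metrics are called and what they return, per class and averaged; a perfect prediction should score 1.0 on both:

# Build a one-hot "perfect" prediction from random ground-truth class indices
target = torch.randint(0, 3, (2, 8, 8))                                        # [B, H, W]
perfect_probs = F.one_hot(target, num_classes=3).permute(0, 3, 1, 2).float()   # [B, C, H, W]
dice_per_class, mean_dice = dice_coefficient(perfect_probs, target, num_classes=3)
iou_per_class, mean_iou = iou_coefficient(perfect_probs, target, num_classes=3)
print(dice_per_class, mean_dice.item())   # all 1.0
print(iou_per_class, mean_iou.item())     # all 1.0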

We have three classes with three colors; here, we define the color map and overlay helpers that will be used later for visualization¶

In [ ]:
# Create color map for visualization
def create_color_map():
    """Create color map for 3 classes: background, class-1, class-2"""
    color_map = np.array([
        [0, 0, 0],       # Background - black
        [128, 0, 0],     # Class 1 - dark red (colors are RGB; matches the dataset's mask color)
        [0, 128, 0],     # Class 2 - dark green
    ], dtype=np.uint8)
    return color_map


def apply_color_map(mask, color_map):
    """Apply color map to single channel mask"""
    h, w = mask.shape
    colored_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx, color in enumerate(color_map):
        colored_mask[mask == class_idx] = color
    return colored_mask


def create_overlay(image, mask, alpha=0.6):
    """Create overlay of image and colored mask"""
    if len(image.shape) == 3 and image.shape[2] == 3:
        # Image is already RGB
        overlay = cv2.addWeighted(image.astype(np.uint8), 1-alpha, mask.astype(np.uint8), alpha, 0)
    else:
        # Convert grayscale to RGB if needed
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        overlay = cv2.addWeighted(image.astype(np.uint8), 1-alpha, mask.astype(np.uint8), alpha, 0)
    return overlay
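
A quick illustration (a minimal sketch with a synthetic mask) of how the color map and overlay helpers combine:

# Color a tiny synthetic class-index mask and overlay it on a flat grey image
toy_mask = np.zeros((64, 64), dtype=np.int64)
toy_mask[:, 20:40] = 1                      # a band of class 1
toy_mask[:, 40:] = 2                        # a band of class 2
toy_colored = apply_color_map(toy_mask, create_color_map())
toy_image = np.full((64, 64, 3), 127, dtype=np.uint8)
toy_overlay = create_overlay(toy_image, toy_colored, alpha=0.5)

plt.figure(figsize=(6, 2))
for i, (panel, title) in enumerate(zip([toy_mask, toy_colored, toy_overlay],
                                       ['class indices', 'colored mask', 'overlay'])):
    plt.subplot(1, 3, i + 1)
    plt.imshow(panel)
    plt.title(title)
    plt.axis('off')
plt.show()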

Define the function that collects the image file paths, reading each file once as a sanity check¶

In [ ]:
# sanity check
def get_image_address(image_data_folder, subfolder_name):
    total_imgs = glob.glob(image_data_folder + subfolder_name + "/*.png")
    error_counter = 0
    image_address_list = []
    for image_name in total_imgs:
        # cv2.imread returns None (rather than raising) if the file is missing or unreadable
        img = cv2.imread(image_name, cv2.IMREAD_UNCHANGED)
        if img is None:
            print("Image not found : ", image_name)
            error_counter += 1
            continue
        image_address_list.append(image_name)
    print("Number of Files not found : ", error_counter)
    print("Total Number of Files found : ", len(image_address_list))
    return image_address_list
In [ ]:
# save checkpoint in pytorch
def save_ckp(checkpoint, checkpoint_path, save_after_epochs):
    if checkpoint['epoch'] % save_after_epochs == 0:
        torch.save(checkpoint, checkpoint_path)


# load checkpoint in pytorch
def load_ckp(checkpoint_path, model, model_opt):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    model_opt.load_state_dict(checkpoint['optimizer'])
    return model, model_opt, checkpoint['epoch']

Here, we define the function for training the model; its core components are:¶

  • model.train() -> puts the model in training mode (BatchNorm and Dropout layers behave accordingly)
  • optimizer.zero_grad() -> clears the gradients accumulated from the previous pass through the computational graph
  • Loss computation between the predicted masks and the ground-truth masks
  • loss.backward() -> performs the backward pass (backpropagation)
  • optimizer.step() -> updates the weights

The train and val functions save the training and validation metrics (i.e., average Dice and IoU of the training and validation dataset) as the training progresses.¶

In [ ]:
def train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes=3):
    print("\n\n---------------------------------------------------------------------------------------------------------------\n")

    progress_bar = tqdm(enumerate(train_loader))
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0

    model.train()

    # Initialize loss function with class weights (optional - adjust based on your class imbalance)
    class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device)  # Less weight for background
    criterion = CategoryCrossEntropyLoss(weight=class_weights)

    for step, (inp__, gt__, file_name) in progress_bar:
        #TRANSFERRING DATA TO DEVICE
        inp__ = inp__.to(device)
        gt__ = gt__.to(device)

        # clear the gradient
        optimizer.zero_grad()

        #GETTING THE PREDICTED IMAGE (raw logits)
        pred_logits = model.forward(inp__)

        # Apply softmax for probability calculation (for metrics only)
        pred_probs = torch.softmax(pred_logits, dim=1)

        #CALCULATING LOSS (using raw logits)
        loss = criterion(pred_logits, gt__)

        #LOSS TAKEN INTO CONSIDERATION
        total_loss += loss.item()

        # Calculate metrics using probabilities
        with torch.no_grad():
            dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
            iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
            total_dice += mean_dice.item()
            total_iou += mean_iou.item()

        #BACKPROPAGATING THE LOSS
        loss.backward()
        optimizer.step()

        #DISPLAYING THE LOSS
        progress_bar.set_description("Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
            epoch, loss.item(), mean_dice.item(), mean_iou.item()))

    avg_loss = total_loss / len(train_loader)
    avg_dice = total_dice / len(train_loader)
    avg_iou = total_iou / len(train_loader)

    with open("{}/train_logs.txt".format(hist_folder_name), "a") as text_file:
        text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))

    print(Fore.GREEN+"Training Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
        epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)

    return model, optimizer

Here, we define the validation epoch. The model is set to model.eval() mode and is run under torch.no_grad() (in the master loop below), so no gradients are computed and no weights are updated; the loss and metrics are still computed, but only for monitoring¶

In [ ]:
def val_epoch(val_loader, model, optimizer, epoch, hist_folder_name, num_classes=3):
    model.eval()
    progress_bar = tqdm(enumerate(val_loader))
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0

    class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device)
    criterion = CategoryCrossEntropyLoss(weight=class_weights)

    for step, (inp__, gt__, file_name) in progress_bar:
        inp__ = inp__.to(device)
        gt__ = gt__.to(device)

        #PREDICTED IMAGE
        pred_logits = model.forward(inp__)
        pred_probs = torch.softmax(pred_logits, dim=1)

        #CALCULATING LOSSES
        loss = criterion(pred_logits, gt__)
        total_loss += loss.item()

        # Calculate metrics
        dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
        iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
        total_dice += mean_dice.item()
        total_iou += mean_iou.item()

        progress_bar.set_description("Val Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
            epoch, loss.item(), mean_dice.item(), mean_iou.item()))

    avg_loss = total_loss / len(val_loader)
    avg_dice = total_dice / len(val_loader)
    avg_iou = total_iou / len(val_loader)

    with open("{}/val_logs.txt".format(hist_folder_name), "a") as text_file:
        text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))

    print(Fore.BLUE+"Validation Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
        epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)

Function for the test epoch -> for sanity check¶

  • Here we test the model on the test dataset after training. We also dump some of the images and inspect how the model performs by overlaying the predicted masks on the input images
In [ ]:
def test_epoch(test_loader, model, optimizer, epoch, inf_folder_name, hist_folder_name, num_classes=3):
    model.eval()
    progress_bar = tqdm(enumerate(test_loader))
    total_loss = 0.0
    total_dice = 0.0
    total_iou = 0.0

    #SETTING THE NUMBER OF IMAGES TO CHECK AFTER EACH ITERATION
    no_img_to_write = 20

    class_weights = torch.FloatTensor([0.5, 1.0, 1.0]).to(device)
    criterion = CategoryCrossEntropyLoss(weight=class_weights)

    # Get color map for visualization
    color_map = create_color_map()

    for step, (inp__, gt__, file_name) in progress_bar:
        inp__ = inp__.to(device)
        gt__ = gt__.to(device)

        #PREDICTED IMAGE
        pred_logits = model.forward(inp__)
        pred_probs = torch.softmax(pred_logits, dim=1)

        #CALCULATING LOSSES
        loss = criterion(pred_logits, gt__)
        total_loss += loss.item()

        # Calculate metrics
        dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
        iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)
        total_dice += mean_dice.item()
        total_iou += mean_iou.item()

        progress_bar.set_description("Test Epoch: {} - Loss: {:.4f} - Dice: {:.4f} - IoU: {:.4f}".format(
            epoch, loss.item(), mean_dice.item(), mean_iou.item()))

        #WRITING THE IMAGES INTO THE SPECIFIED DIRECTORY
        if(step < no_img_to_write):

            pred_classes = torch.argmax(pred_probs, dim=1)  # Get class predictions
            p_img = pred_classes.cpu().numpy()  # [B, H, W]
            gt_img = gt__.cpu().numpy()  # [B, H, W]
            inp_img = inp__.cpu().numpy()  # [B, C, H, W]

            #FOLDER PATH TO WRITE THE INFERENCES
            inference_folder = "{}".format(inf_folder_name)
            if not os.path.isdir(inference_folder):
                os.mkdir(inference_folder)

            print("\n Saving inferences at epoch === ",epoch)

            # Save inference images for multi-class with overlays
            for batch_idx, (p_image_loop, gt_img_loop, inp_img_loop) in enumerate(zip(p_img, gt_img, inp_img)):

                # Prepare input image for overlay
                inp_img_display = np.transpose(inp_img_loop, (1, 2, 0))  # [H, W, C]
                inp_img_display = (inp_img_display * 255.0).astype(np.uint8)

                # 1. Save raw input image
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_input.png"),
                           cv2.cvtColor(inp_img_display, cv2.COLOR_RGB2BGR))

                # 2. Create and save colored prediction mask
                pred_colored = apply_color_map(p_image_loop, color_map)
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_pred_mask.png"),
                           cv2.cvtColor(pred_colored, cv2.COLOR_RGB2BGR))

                # 3. Create and save colored ground truth mask
                gt_colored = apply_color_map(gt_img_loop, color_map)
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_gt_mask.png"),
                           cv2.cvtColor(gt_colored, cv2.COLOR_RGB2BGR))

                # 4. Create and save prediction overlay
                pred_overlay = create_overlay(inp_img_display, pred_colored, alpha=0.4)
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_pred_overlay.png"),
                           cv2.cvtColor(pred_overlay, cv2.COLOR_RGB2BGR))

                # 5. Create and save ground truth overlay
                gt_overlay = create_overlay(inp_img_display, gt_colored, alpha=0.4)
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_gt_overlay.png"),
                           cv2.cvtColor(gt_overlay, cv2.COLOR_RGB2BGR))

                # 6. Create side-by-side comparison
                comparison = np.hstack([
                    inp_img_display,
                    gt_colored,
                    pred_colored,
                    gt_overlay,
                    pred_overlay
                ])
                cv2.imwrite(os.path.join(inference_folder, str(file_name[batch_idx])+"_comparison.png"),
                           cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))

    avg_loss = total_loss / len(test_loader)
    avg_dice = total_dice / len(test_loader)
    avg_iou = total_iou / len(test_loader)

    with open("{}/test_logs.txt".format(hist_folder_name), "a") as text_file:
        text_file.write("{} {:.6f} {:.6f} {:.6f}\n".format(epoch, avg_loss, avg_dice, avg_iou))

    print(Fore.RED+"Test Epoch: {} | Loss: {:.4f} | Dice: {:.4f} | IoU: {:.4f}".format(
        epoch, avg_loss, avg_dice, avg_iou)+Style.RESET_ALL)
    print("---------------------------------------------------------------------------------------------------------------")

Master loop to control the train, validation, and test epochs. Here, we can save a checkpoint after each epoch (or after a specified number of epochs) and either resume training from a checkpoint or start from scratch. This section also creates the folders for dumping the inferences and the history/metrics recorded during training¶

In [ ]:
def train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume, model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes=3):
    checkpoint_folder_name = "checkpoint"
    Path(checkpoint_folder_name).mkdir(parents=True, exist_ok=True)
    Path(inf_folder_name).mkdir(parents=True, exist_ok=True)
    Path(hist_folder_name).mkdir(parents=True, exist_ok=True)

    epoch = 0

    #PATH TO SAVE THE CHECKPOINT
    checkpoint_path = "checkpoint/{}_{}.pt".format(model_name,epoch)

    #IF TRAINING IS TO RESUMED FROM A CERTAIN CHECKPOINT
    if resume:
        model, optimizer, epoch = load_ckp(
            checkpoint_path, model, optimizer)

    while epoch <= n_epoch:
        checkpoint_path = "checkpoint/{}_{}.pt".format(model_name,epoch)
        epoch += 1
        model, optimizer = train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes)

        #CHECKPOINT CREATION
        checkpoint = {'epoch': epoch+1, 'state_dict': model.state_dict(),
                      'optimizer': optimizer.state_dict()}

        #CHECKPOINT SAVING
        save_ckp(checkpoint, checkpoint_path, save_after_epochs)
        print(Fore.BLACK+"Checkpoint Saved"+Style.RESET_ALL)

        with torch.no_grad():
            val_epoch(val_loader, model, optimizer, epoch, hist_folder_name, num_classes)

    print("************************ Final Test Epoch *****************************")

    with torch.no_grad():
        test_epoch(test_loader, model, optimizer, epoch, inf_folder_name, hist_folder_name, num_classes)

The main function, where all the training is orchestrated. We can pass the arguments using argparse, or define a small class to hold the hyper-parameters and values needed for training. We load the train, validation and test image-mask pairs, build the dataloaders, and then call the master epoch loop to perform training, validation and testing.¶
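
As a side note, the same hyper-parameters could be exposed on the command line instead of being hard-coded in a class; a minimal sketch of the equivalent argparse setup for a standalone script (the flag names below are illustrative, not part of the notebook):

import argparse

parser = argparse.ArgumentParser(description="Multiclass segmentation training")
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_epochs", type=int, default=2)
parser.add_argument("--num_classes", type=int, default=3)
parser.add_argument("--train_images", default="train_images")
parser.add_argument("--train_masks", default="train_masks")
# ... the validation/test folders and output folders follow the same pattern
args = parser.parse_args()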

In [ ]:
def main():
    # Set your parameters directly here for Jupyter notebook
    class Args:
        def __init__(self):
            # Training parameters
            self.lr = "0.001"
            self.batch_size = "4"
            self.num_epochs = "2"
            self.num_classes = 3

            # Dataset folders
            self.train_images = "train_images"
            self.val_images = "val_images"
            self.test_images = "test_images"

            self.train_masks = "train_masks"
            self.val_masks = "val_masks"
            self.test_masks = "test_masks"

            # Output folders
            self.history_folder_name = "history"
            self.inference_folder_name = "inference"
            self.chkpt_name = "multiclass_model"
            self.save_after_epoch = "1"

    args = Args()

    train_image_address_list = get_image_address("./", args.train_images)
    print("Total Number of Training Images : ", len(train_image_address_list))

    val_image_address_list = get_image_address("./", args.val_images)
    test_image_address_list = get_image_address("./", args.test_images)

    train_masks = args.train_masks
    val_masks = args.val_masks
    test_masks = args.test_masks
    save_after_epochs = int(args.save_after_epoch)
    num_classes = args.num_classes

    # CREATING THE TRAIN LOADER
    train_loader = load_data(
        train_image_address_list, masks_folder=train_masks, batch_size=int(args.batch_size),
        num_workers=8, shuffle=True, num_classes=num_classes)

    val_loader = load_data(
        val_image_address_list, masks_folder=val_masks, batch_size=int(args.batch_size),
        num_workers=8, shuffle=True, num_classes=num_classes)

    # CREATING THE TEST LOADER
    test_loader = load_data(
        test_image_address_list, masks_folder=test_masks, batch_size=1,
        num_workers=8, shuffle=False, num_classes=num_classes)

    # CALLING THE MODEL - Make sure your model outputs num_classes channels
    model = ResUnetPlusPlus(input_ch=3, n_class=num_classes)  # Change this to match your model
    # model = nn.DataParallel(model)
    model = model.to(device)

    # DEFINING THE OPTIMIZER
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=float(args.lr), weight_decay=5e-4)

    n_epoch = int(args.num_epochs)

    # INDICATOR VARIABLE TO RESUME TRAINING OR START AFRESH
    resume = False
    model_name = args.chkpt_name
    inf_folder_name = args.inference_folder_name
    hist_folder_name = args.history_folder_name

    train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume,
                   model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes)

if __name__ == "__main__":
    main()
Number of Files not found :  0
Total Number of Files found :  4723
Total Number of Training Images :  4723
Number of Files not found :  0
Total Number of Files found :  787
Number of Files not found :  0
Total Number of Files found :  2362
/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(

---------------------------------------------------------------------------------------------------------------

/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py:627: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(
Epoch: 1 - Loss: 0.9149 - Dice: 0.5427 - IoU: 0.4689: : 54it [00:10,  4.93it/s]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipython-input-893361213.py in <cell line: 0>()
     73 
     74 if __name__ == "__main__":
---> 75     main()

/tmp/ipython-input-893361213.py in main()
     69     hist_folder_name = args.history_folder_name
     70 
---> 71     train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume,
     72                    model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes)
     73 

/tmp/ipython-input-3502321866.py in train_val_test(train_loader, val_loader, test_loader, model, optimizer, n_epoch, resume, model_name, inf_folder_name, hist_folder_name, save_after_epochs, num_classes)
     18         checkpoint_path = "checkpoint/{}_{}.pt".format(model_name,epoch)
     19         epoch += 1
---> 20         model, optimizer = train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes)
     21 
     22         #CHECKPOINT CREATION

/tmp/ipython-input-4030073634.py in train_epoch(train_loader, model, optimizer, epoch, hist_folder_name, num_classes)
     31 
     32         #LOSS TAKEN INTO CONSIDERATION
---> 33         total_loss += loss.item()
     34 
     35         # Calculate metrics using probabilities

KeyboardInterrupt: 

Download the pre-trained model and the training history (trained for 1000 epochs) for loading the best model¶

In [ ]:
! wget https://jimut123.github.io/talks/2025_dlnlp/multiclass_model_998.pt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/train_logs.txt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/val_logs.txt
! wget https://jimut123.github.io/talks/2025_dlnlp/history/test_logs.txt
--2025-09-14 10:56:19--  https://jimut123.github.io/talks/2025_dlnlp/multiclass_model_998.pt
Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49032778 (47M) [application/octet-stream]
Saving to: ‘multiclass_model_998.pt’

multiclass_model_99 100%[===================>]  46.76M   273MB/s    in 0.2s    

2025-09-14 10:56:22 (273 MB/s) - ‘multiclass_model_998.pt’ saved [49032778/49032778]

--2025-09-14 10:56:22--  https://jimut123.github.io/talks/2025_dlnlp/history/train_logs.txt
Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30957 (30K) [text/plain]
Saving to: ‘train_logs.txt’

train_logs.txt      100%[===================>]  30.23K  --.-KB/s    in 0.001s  

2025-09-14 10:56:23 (47.1 MB/s) - ‘train_logs.txt’ saved [30957/30957]

--2025-09-14 10:56:23--  https://jimut123.github.io/talks/2025_dlnlp/history/val_logs.txt
Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30957 (30K) [text/plain]
Saving to: ‘val_logs.txt’

val_logs.txt        100%[===================>]  30.23K  --.-KB/s    in 0.001s  

2025-09-14 10:56:23 (52.9 MB/s) - ‘val_logs.txt’ saved [30957/30957]

--2025-09-14 10:56:23--  https://jimut123.github.io/talks/2025_dlnlp/history/test_logs.txt
Resolving jimut123.github.io (jimut123.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to jimut123.github.io (jimut123.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 32 [text/plain]
Saving to: ‘test_logs.txt’

test_logs.txt       100%[===================>]      32  --.-KB/s    in 0s      

2025-09-14 10:56:24 (1.90 MB/s) - ‘test_logs.txt’ saved [32/32]

In [ ]:
def load_log_file(filename):
    """Load log file and return arrays for epoch, loss, dice, iou"""
    data = np.loadtxt(filename)
    epochs = data[:, 0]
    loss = data[:, 1]
    dice = data[:, 2]
    iou = data[:, 3]
    return epochs, loss, dice, iou

def plot_metrics(train_epochs, train_loss, train_dice, train_iou,
                val_epochs, val_loss, val_dice, val_iou,
                save_path='training_history.png'):
    """Plot training and validation metrics"""

    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Training History Metrics', fontsize=16, fontweight='bold', y=0.98)

    # Plot Loss
    ax1.plot(train_epochs, train_loss, 'b-', linewidth=2, label='Training Loss', alpha=0.8)
    ax1.plot(val_epochs, val_loss, 'r-', linewidth=2, label='Validation Loss', alpha=0.8)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Categorical CrossEntropy Loss')
    ax1.set_title('Loss Curve')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot Dice Coefficient
    ax2.plot(train_epochs, train_dice, 'b-', linewidth=2, label='Training Dice', alpha=0.8)
    ax2.plot(val_epochs, val_dice, 'r-', linewidth=2, label='Validation Dice', alpha=0.8)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Coefficient')
    ax2.set_title('Dice Coefficient')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot IoU
    ax3.plot(train_epochs, train_iou, 'b-', linewidth=2, label='Training IoU', alpha=0.8)
    ax3.plot(val_epochs, val_iou, 'r-', linewidth=2, label='Validation IoU', alpha=0.8)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('IoU')
    ax3.set_title('Intersection over Union (IoU)')
    ax3.legend()
    ax3.grid(True, alpha=0.3)

    # Plot all metrics together for comparison
    ax4.plot(train_epochs, train_loss, 'b-', linewidth=2, label='Train Loss', alpha=0.8)
    ax4.plot(train_epochs, train_dice, 'g-', linewidth=2, label='Train Dice', alpha=0.8)
    ax4.plot(train_epochs, train_iou, 'm-', linewidth=2, label='Train IoU', alpha=0.8)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Metric Value')
    ax4.set_title('Training Metrics Comparison')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    plt.close()


    print(f"Plot saved as {save_path}")
In [ ]:
# Load training and validation logs
try:
    train_epochs, train_loss, train_dice, train_iou = load_log_file('train_logs.txt')
    val_epochs, val_loss, val_dice, val_iou = load_log_file('val_logs.txt')

    # Plot and save the metrics
    plot_metrics(train_epochs, train_loss, train_dice, train_iou,
                val_epochs, val_loss, val_dice, val_iou,
                save_path='training_history.png')

except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Please make sure both 'train_logs.txt' and 'val_logs.txt' exist in the current directory.")
except Exception as e:
    print(f"Error loading files: {e}")
[Figure: training history curves for loss, Dice coefficient and IoU]
Plot saved as training_history.png

In this function, we go through the image samples from the test folder and dump them into the test_dumps folder. We also overlay the predicted masks and the ground-truth masks on the test images and save those overlays. The function also returns the per-image list of Dice and IoU metrics to be used later¶

In [ ]:
def test_dump_with_checkpoint(test_loader, checkpoint_path, dump_folder="test_dumps", num_classes=3):
    """
    Load checkpoint and dump test images with predictions and overlays
    """
    # Create dump folder
    Path(dump_folder).mkdir(parents=True, exist_ok=True)

    # Load model and checkpoint
    model = ResUnetPlusPlus(input_ch=3, n_class=num_classes)
    model = model.to(device)

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    print(f"Loaded checkpoint from epoch: {checkpoint.get('epoch', 'Unknown')}")

    model.eval()
    progress_bar = tqdm(enumerate(test_loader))

    # Get color map for visualization
    color_map = create_color_map()

    # Store metrics for each image
    image_metrics = []

    with torch.no_grad():
        for step, (inp__, gt__, file_name) in progress_bar:
            inp__ = inp__.to(device)
            gt__ = gt__.to(device)

            # Get predictions
            pred_logits = model.forward(inp__)
            pred_probs = torch.softmax(pred_logits, dim=1)
            pred_classes = torch.argmax(pred_probs, dim=1)

            # Calculate metrics for this batch
            dice_per_class, mean_dice = dice_coefficient(pred_probs, gt__, num_classes)
            iou_per_class, mean_iou = iou_coefficient(pred_probs, gt__, num_classes)

            # Convert to numpy
            p_img = pred_classes.cpu().numpy()
            gt_img = gt__.cpu().numpy()
            inp_img = inp__.cpu().numpy()

            progress_bar.set_description(f"Processing image {step+1} - Dice: {mean_dice.item():.4f}")

            # Save images for each item in batch
            for batch_idx, (p_image_loop, gt_img_loop, inp_img_loop) in enumerate(zip(p_img, gt_img, inp_img)):

                # Store metrics for this image
                current_dice_scores = dice_per_class.cpu().numpy()
                current_iou_scores = iou_per_class.cpu().numpy()
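                # Note: dice_per_class / iou_per_class above are computed per batch; since the
                # test loader uses batch_size=1, they correspond to this single image.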

                image_info = {
                    'filename': str(file_name[batch_idx]),
                    'dice_class1': current_dice_scores[1],  # Class 1 dice
                    'dice_class2': current_dice_scores[2],  # Class 2 dice
                    'dice_mean': mean_dice.item(),
                    'iou_class1': current_iou_scores[1],
                    'iou_class2': current_iou_scores[2],
                    'iou_mean': mean_iou.item()
                }
                image_metrics.append(image_info)

                # Prepare input image for overlay
                inp_img_display = np.transpose(inp_img_loop, (1, 2, 0))
                inp_img_display = (inp_img_display * 255.0).astype(np.uint8)

                # Create colored masks
                pred_colored = apply_color_map(p_image_loop, color_map)
                gt_colored = apply_color_map(gt_img_loop, color_map)

                # Create overlays
                pred_overlay = create_overlay(inp_img_display, pred_colored, alpha=0.4)
                gt_overlay = create_overlay(inp_img_display, gt_colored, alpha=0.4)

                # Save all images
                base_name = str(file_name[batch_idx])
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_input.png"),
                           cv2.cvtColor(inp_img_display, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_pred_mask.png"),
                           cv2.cvtColor(pred_colored, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_gt_mask.png"),
                           cv2.cvtColor(gt_colored, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_pred_overlay.png"),
                           cv2.cvtColor(pred_overlay, cv2.COLOR_RGB2BGR))
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_gt_overlay.png"),
                           cv2.cvtColor(gt_overlay, cv2.COLOR_RGB2BGR))

                # Create side-by-side comparison
                comparison = np.hstack([inp_img_display, gt_colored, pred_colored, gt_overlay, pred_overlay])
                cv2.imwrite(os.path.join(dump_folder, f"{base_name}_comparison.png"),
                           cv2.cvtColor(comparison, cv2.COLOR_RGB2BGR))

    print(f"\nCompleted test dump. Images saved in: {dump_folder}")
    return image_metrics
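
If the per-image scores are needed outside the notebook (for example, for further error analysis), the returned list of dicts can also be persisted; a small optional sketch (the CSV path is illustrative and not part of the original flow):

import csv

def save_metrics_csv(image_metrics, csv_path="test_dumps/per_image_metrics.csv"):
    # One row per test image with its per-class and mean Dice/IoU scores
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(image_metrics[0].keys()))
        writer.writeheader()
        writer.writerows(image_metrics)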

This function calculates the mean and standard deviation of the per-image Dice and IoU scores of the predicted masks against the ground-truth masks¶

In [ ]:
def calculate_class_metrics(image_metrics):
    """
    Calculate mean ± std for dice coefficients of class 1 and class 2
    """
    dice_class1 = [img['dice_class1'] for img in image_metrics]
    dice_class2 = [img['dice_class2'] for img in image_metrics]

    # Calculate statistics
    dice_c1_mean = np.mean(dice_class1)
    dice_c1_std = np.std(dice_class1)
    dice_c2_mean = np.mean(dice_class2)
    dice_c2_std = np.std(dice_class2)

    print("\n" + "="*60)
    print("DICE COEFFICIENT RESULTS")
    print("="*60)
    print(f"Class 1 Dice: {dice_c1_mean:.4f} ± {dice_c1_std:.4f}")
    print(f"Class 2 Dice: {dice_c2_mean:.4f} ± {dice_c2_std:.4f}")
    print("="*60)

    # Also calculate overall metrics
    overall_dice = [img['dice_mean'] for img in image_metrics]
    iou_class1 = [img['iou_class1'] for img in image_metrics]
    iou_class2 = [img['iou_class2'] for img in image_metrics]
    overall_iou = [img['iou_mean'] for img in image_metrics]

    print(f"Overall Dice: {np.mean(overall_dice):.4f} ± {np.std(overall_dice):.4f}\n\n")
    print(f"Class 1 IoU: {np.mean(iou_class1):.4f} ± {np.std(iou_class1):.4f}")
    print(f"Class 2 IoU: {np.mean(iou_class2):.4f} ± {np.std(iou_class2):.4f}")
    print(f"Overall IoU: {np.mean(overall_iou):.4f} ± {np.std(overall_iou):.4f}")

    return {
        'dice_class1': {'mean': dice_c1_mean, 'std': dice_c1_std},
        'dice_class2': {'mean': dice_c2_mean, 'std': dice_c2_std},
        'overall_dice': {'mean': np.mean(overall_dice), 'std': np.std(overall_dice)}
    }
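
Since the summary reports both Dice and IoU, it may help to recall how the two relate for a single binary class; a toy, self-contained example (not part of the pipeline):

import numpy as np

# Dice = 2*TP / (2*TP + FP + FN), IoU = TP / (TP + FP + FN), and Dice = 2*IoU / (1 + IoU)
pred = np.array([1, 1, 0, 0, 1])
gt   = np.array([1, 0, 0, 1, 1])
tp = np.sum((pred == 1) & (gt == 1))   # 2
fp = np.sum((pred == 1) & (gt == 0))   # 1
fn = np.sum((pred == 0) & (gt == 1))   # 1
print(2 * tp / (2 * tp + fp + fn))     # Dice ~ 0.667
print(tp / (tp + fp + fn))             # IoU = 0.5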

This function displays the best-performing and the worst-performing image-prediction overlays, ranked by a chosen metric of the model's predictions.¶

In [ ]:
def display_best_worst_images(image_metrics, dump_folder, top_k=5, criterion='dice_mean'):
    """
    Display the best and worst performing images based on specified criterion
    """
    # Sort images by performance
    sorted_metrics = sorted(image_metrics, key=lambda x: x[criterion])

    worst_images = sorted_metrics[:top_k]  # Lowest scores
    best_images = sorted_metrics[-top_k:]  # Highest scores
    best_images.reverse()  # Show best first

    print(f"\n" + "="*80)
    print(f"TOP {top_k} BEST PERFORMING IMAGES (by {criterion})")
    print("="*80)

    fig, axes = plt.subplots(top_k, 5, figsize=(20, 4*top_k))
    if top_k == 1:
        axes = axes.reshape(1, -1)

    for idx, img_info in enumerate(best_images):
        filename = img_info['filename']
        score = img_info[criterion]

        print(f"{idx+1}. {filename} - {criterion}: {score:.4f}")
        print(f"   Class 1 Dice: {img_info['dice_class1']:.4f}, Class 2 Dice: {img_info['dice_class2']:.4f}")

        # Load and display images
        input_img = cv2.imread(os.path.join(dump_folder, f"{filename}_input.png"))
        gt_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_mask.png"))
        pred_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_mask.png"))
        gt_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_overlay.png"))
        pred_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_overlay.png"))

        # Convert BGR to RGB for matplotlib
        images = [input_img, gt_mask, pred_mask, gt_overlay, pred_overlay]
        titles = ['Input', 'GT Mask', 'Pred Mask', 'GT Overlay', 'Pred Overlay']

        for j, (img, title) in enumerate(zip(images, titles)):
            if img is not None:
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                axes[idx, j].imshow(img_rgb)
                axes[idx, j].set_title(f'{title}\n{criterion}: {score:.3f}')
                axes[idx, j].axis('off')

    plt.suptitle(f'Best {top_k} Images by {criterion}', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.show()

    print(f"\n" + "="*80)
    print(f"TOP {top_k} WORST PERFORMING IMAGES (by {criterion})")
    print("="*80)

    fig, axes = plt.subplots(top_k, 5, figsize=(20, 4*top_k))
    if top_k == 1:
        axes = axes.reshape(1, -1)

    for idx, img_info in enumerate(worst_images):
        filename = img_info['filename']
        score = img_info[criterion]

        print(f"{idx+1}. {filename} - {criterion}: {score:.4f}")
        print(f"   Class 1 Dice: {img_info['dice_class1']:.4f}, Class 2 Dice: {img_info['dice_class2']:.4f}")

        # Load and display images
        input_img = cv2.imread(os.path.join(dump_folder, f"{filename}_input.png"))
        gt_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_mask.png"))
        pred_mask = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_mask.png"))
        gt_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_gt_overlay.png"))
        pred_overlay = cv2.imread(os.path.join(dump_folder, f"{filename}_pred_overlay.png"))

        # Convert BGR to RGB for matplotlib
        images = [input_img, gt_mask, pred_mask, gt_overlay, pred_overlay]
        titles = ['Input', 'GT Mask', 'Pred Mask', 'GT Overlay', 'Pred Overlay']

        for j, (img, title) in enumerate(zip(images, titles)):
            if img is not None:
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                axes[idx, j].imshow(img_rgb)
                axes[idx, j].set_title(f'{title}\n{criterion}: {score:.3f}')
                axes[idx, j].axis('off')

    plt.suptitle(f'Worst {top_k} Images by {criterion}', fontsize=16, y=0.98)
    plt.tight_layout()
    plt.show()

Shows the class-wise and overall Dice and IoU of the predicted masks¶

In [ ]:
# Parameters - load from the saved checkpoint and mention the dump folder name
checkpoint_path = "multiclass_model_998.pt"
dump_folder = "test_dumps"
num_classes = 3

# Get test image addresses
test_image_address_list = get_image_address("./", "test_images")
print(f"Total Number of Test Images: {len(test_image_address_list)}")

# Create test loader
test_loader = load_data(
    test_image_address_list, masks_folder="test_masks", batch_size=1,
    num_workers=8, shuffle=False, num_classes=num_classes)

# Run test dump
print("Starting test dump with checkpoint loading...")
image_metrics = test_dump_with_checkpoint(
    test_loader, checkpoint_path, dump_folder, num_classes)

# Calculate and display metrics
print("Calculating class-wise metrics...")
metrics_summary = calculate_class_metrics(image_metrics)
Number of Files not found :  0
Total Number of Files found :  2362
Total Number of Test Images: 2362
Starting test dump with checkpoint loading...
Loaded checkpoint from epoch: 1000
Processing image 2362 - Dice: 0.7259: : 2362it [02:44, 14.38it/s]
Completed test dump. Images saved in: test_dumps
Calculating class-wise metrics...

============================================================
DICE COEFFICIENT RESULTS
============================================================
Class 1 Dice: 0.8544 ± 0.2623
Class 2 Dice: 0.9306 ± 0.2090
============================================================
Overall Dice: 0.9227 ± 0.1426


Class 1 IoU: 0.8077 ± 0.2738
Class 2 IoU: 0.9136 ± 0.2248
Overall IoU: 0.8964 ± 0.1586
In [ ]:
# Display best and worst images
print("Displaying best and worst performing images...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_mean')
Displaying best and worst performing images...

================================================================================
TOP 3 BEST PERFORMING IMAGES (by dice_mean)
================================================================================
1. 0129_5 - dice_mean: 0.9991
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9983
2. 0654_1 - dice_mean: 0.9989
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9987
3. 0573_0 - dice_mean: 0.9988
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9977
[Figure: best 3 test images by dice_mean (input, GT mask, predicted mask, GT overlay, predicted overlay)]
================================================================================
TOP 3 WORST PERFORMING IMAGES (by dice_mean)
================================================================================
1. 0136_5 - dice_mean: 0.2355
   Class 1 Dice: 0.0000, Class 2 Dice: 0.0500
2. 0136_3 - dice_mean: 0.2536
   Class 1 Dice: 0.0000, Class 2 Dice: 0.0734
3. 0136_2 - dice_mean: 0.2782
   Class 1 Dice: 0.0000, Class 2 Dice: 0.1701
[Figure: worst 3 test images by dice_mean (input, GT mask, predicted mask, GT overlay, predicted overlay)]
In [ ]:
# display based on specific class performance
print("\nDisplaying based on Class 1 Dice performance...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_class1')
Displaying based on Class 1 Dice performance...

================================================================================
TOP 3 BEST PERFORMING IMAGES (by dice_class1)
================================================================================
1. 0347_1 - dice_class1: 1.0000
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9886
2. 0577_5 - dice_class1: 1.0000
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9909
3. 0283_4 - dice_class1: 1.0000
   Class 1 Dice: 1.0000, Class 2 Dice: 0.9908
[Figure: best 3 test images by dice_class1]
================================================================================
TOP 3 WORST PERFORMING IMAGES (by dice_class1)
================================================================================
1. 0394_4 - dice_class1: 0.0000
   Class 1 Dice: 0.0000, Class 2 Dice: 0.7643
2. 0024_5 - dice_class1: 0.0000
   Class 1 Dice: 0.0000, Class 2 Dice: 0.8337
3. 0653_4 - dice_class1: 0.0000
   Class 1 Dice: 0.0000, Class 2 Dice: 0.9222
[Figure: worst 3 test images by dice_class1]
In [ ]:
print("\nDisplaying based on Class 2 Dice performance...")
display_best_worst_images(image_metrics, dump_folder, top_k=3, criterion='dice_class2')
Displaying based on Class 2 Dice performance...

================================================================================
TOP 3 BEST PERFORMING IMAGES (by dice_class2)
================================================================================
1. 0202_5 - dice_class2: 1.0000
   Class 1 Dice: 0.9678, Class 2 Dice: 1.0000
2. 0173_4 - dice_class2: 1.0000
   Class 1 Dice: 0.9186, Class 2 Dice: 1.0000
3. 0775_1 - dice_class2: 1.0000
   Class 1 Dice: 0.9272, Class 2 Dice: 1.0000
[Figure: best 3 test images by dice_class2]
================================================================================
TOP 3 WORST PERFORMING IMAGES (by dice_class2)
================================================================================
1. 0424_2 - dice_class2: 0.0000
   Class 1 Dice: 0.4869, Class 2 Dice: 0.0000
2. 0061_5 - dice_class2: 0.0000
   Class 1 Dice: 1.0000, Class 2 Dice: 0.0000
3. 0221_5 - dice_class2: 0.0000
   Class 1 Dice: 0.3215, Class 2 Dice: 0.0000
[Figure: worst 3 test images by dice_class2]

We visualize the kernel patterns and the output feature maps for some of the layers¶

In [ ]:
def visualize_conv_kernels(model, layer_name, save_folder="kernel_patterns"):
    """
    Visualize and save convolutional kernel patterns from a specific layer
    """
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    # Get the layer by name
    layer = dict(model.named_modules())[layer_name]

    if not isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        print(f"Layer {layer_name} is not a convolutional layer")
        return

    # Get kernel weights [out_channels, in_channels, kernel_h, kernel_w]
    kernels = layer.weight.data.cpu().numpy()

    print(f"Layer: {layer_name}")
    print(f"Kernel shape: {kernels.shape}")

    # Number of output channels (filters)
    num_filters = kernels.shape[0]
    num_input_channels = kernels.shape[1]

    # Create subplots - show first 64 filters max
    max_filters = min(64, num_filters)
    cols = 8
    rows = (max_filters + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)

    for i in range(max_filters):
        row = i // cols
        col = i % cols

        # For multi-channel input, take mean across input channels
        if num_input_channels > 1:
            kernel_img = np.mean(kernels[i], axis=0)
        else:
            kernel_img = kernels[i, 0]

        # Normalize to [0, 1] for visualization
        kernel_img = (kernel_img - kernel_img.min()) / (kernel_img.max() - kernel_img.min() + 1e-8)

        axes[row, col].imshow(kernel_img, cmap='viridis', interpolation='nearest')
        axes[row, col].set_title(f'Filter {i}', fontsize=8)
        axes[row, col].axis('off')

    # Hide unused subplots
    for i in range(max_filters, rows * cols):
        row = i // cols
        col = i % cols
        axes[row, col].axis('off')

    plt.suptitle(f'Kernels from {layer_name}', fontsize=14)
    plt.tight_layout()
    plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_kernels.png', dpi=300, bbox_inches='tight')
    plt.show()

    return kernels


def visualize_feature_maps(model, input_tensor, layer_name, save_folder="feature_maps"):
    """
    Visualize feature maps from a specific layer given an input
    """
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    # Hook to capture feature maps
    feature_maps = {}

    def hook_fn(module, input, output):
        feature_maps[layer_name] = output.detach()

    # Register hook
    layer = dict(model.named_modules())[layer_name]
    handle = layer.register_forward_hook(hook_fn)

    # Forward pass
    model.eval()
    with torch.no_grad():
        _ = model(input_tensor)

    # Remove hook
    handle.remove()

    # Get feature maps
    fmaps = feature_maps[layer_name].cpu().numpy()

    print(f"Feature maps shape: {fmaps.shape}")  # [batch, channels, height, width]

    # Take first sample from batch
    fmaps = fmaps[0]  # [channels, height, width]

    # Show first 64 feature maps
    max_maps = min(64, fmaps.shape[0])
    cols = 8
    rows = (max_maps + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)

    for i in range(max_maps):
        row = i // cols
        col = i % cols

        fmap = fmaps[i]
        axes[row, col].imshow(fmap, cmap='viridis')
        axes[row, col].set_title(f'Channel {i}', fontsize=8)
        axes[row, col].axis('off')

    # Hide unused subplots
    for i in range(max_maps, rows * cols):
        row = i // cols
        col = i % cols
        axes[row, col].axis('off')

    plt.suptitle(f'Feature Maps from {layer_name}', fontsize=14)
    plt.tight_layout()
    plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_fmaps.png', dpi=300, bbox_inches='tight')
    plt.show()

    return fmaps


def analyze_model_patterns(checkpoint_path, sample_input=None):
    """
    Main function to analyze kernel patterns and feature maps
    """
    # Load model
    model = ResUnetPlusPlus(input_ch=3, n_class=3)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # Print all layer names to help identify layers of interest
    print("Available layers:")
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
            print(f"  {name}: {type(module).__name__} {module.weight.shape}")

    # Create sample input if not provided
    if sample_input is None:
        sample_input = torch.randn(1, 3, 256, 256).to(device)

    # Analyze different layers - modify these based on your model architecture
    layers_to_analyze = [
        # First layer (typically captures edges/basic patterns)
        "encoder.0",  # or whatever your first conv layer is named

        # Middle layers (intermediate features)
        "encoder.4",  # adjust based on your model

        # Final layers (high-level features)
        "decoder.0",  # adjust based on your model
        "final_conv",  # or your final output layer
    ]

    # You may need to adjust layer names based on your actual model structure
    # Uncomment and modify the layer names that exist in your model:

    # Example for common layer names:
    # layers_to_analyze = [
    #     "conv1",
    #     "layer1.0.conv1",
    #     "layer2.0.conv1",
    #     "layer3.0.conv1",
    #     "layer4.0.conv1",
    #     "classifier"  # final layer
    # ]

    print("\nAnalyzing kernel patterns...")
    for layer_name in layers_to_analyze:
        try:
            print(f"\n--- Analyzing {layer_name} ---")

            # Visualize kernels
            kernels = visualize_conv_kernels(model, layer_name)

            # Visualize feature maps
            feature_maps = visualize_feature_maps(model, sample_input, layer_name)

        except KeyError:
            print(f"Layer '{layer_name}' not found in model. Skipping...")
            continue
        except Exception as e:
            print(f"Error analyzing layer '{layer_name}': {e}")
            continue


def visualize_conv_kernels_rgb(model, layer_name, save_folder="kernel_patterns"):
    """
    Visualize convolutional kernels as RGB-style images
    """
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    layer = dict(model.named_modules())[layer_name]
    if not isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        print(f"Layer {layer_name} is not a convolutional layer")
        return

    kernels = layer.weight.data.cpu().numpy()
    print(f"Layer: {layer_name}, Kernel shape: {kernels.shape}")

    num_filters = kernels.shape[0]
    num_input_channels = kernels.shape[1]

    # Show up to 32 filters for better visualization
    max_filters = min(32, num_filters)
    cols = 8
    rows = (max_filters + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(16, 2*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)

    for i in range(max_filters):
        row = i // cols
        col = i % cols

        kernel = kernels[i]

        if num_input_channels >= 3:
            # Take first 3 channels as RGB
            kernel_rgb = kernel[:3]  # Shape: [3, H, W]
            kernel_rgb = np.transpose(kernel_rgb, (1, 2, 0))  # [H, W, 3]

            # Normalize each channel independently
            for c in range(3):
                channel = kernel_rgb[:, :, c]
                kernel_rgb[:, :, c] = (channel - channel.min()) / (channel.max() - channel.min() + 1e-8)

        elif num_input_channels == 1:
            # Convert single channel to RGB (grayscale)
            kernel_gray = kernel[0]
            kernel_gray = (kernel_gray - kernel_gray.min()) / (kernel_gray.max() - kernel_gray.min() + 1e-8)
            kernel_rgb = np.stack([kernel_gray] * 3, axis=-1)
        else:
            # For other cases, use mean across input channels
            kernel_mean = np.mean(kernel, axis=0)
            kernel_mean = (kernel_mean - kernel_mean.min()) / (kernel_mean.max() - kernel_mean.min() + 1e-8)
            kernel_rgb = np.stack([kernel_mean] * 3, axis=-1)

        axes[row, col].imshow(kernel_rgb)
        axes[row, col].set_title(f'Filter {i}', fontsize=8)
        axes[row, col].axis('off')

    # Hide unused subplots
    for i in range(max_filters, rows * cols):
        row = i // cols
        col = i % cols
        axes[row, col].axis('off')

    plt.suptitle(f'RGB-style Kernels from {layer_name}', fontsize=14)
    plt.tight_layout()
    plt.savefig(f'{save_folder}/{layer_name.replace(".", "_")}_rgb_kernels.png', dpi=300, bbox_inches='tight')
    plt.show()

    return kernels


def generate_grad_cam(model, input_tensor, target_layer, target_class=None):
    """
    Generate Grad-CAM heatmap for understanding what the network focuses on
    """
    model.eval()
    input_tensor.requires_grad_(True)  # Enable gradients on input

    # Store gradients and feature maps
    gradients = []
    activations = []

    def backward_hook(module, grad_input, grad_output):
        if grad_output[0] is not None:
            gradients.append(grad_output[0].detach())

    def forward_hook(module, input, output):
        activations.append(output.detach())

    try:
        # Register hooks
        target_layer_module = dict(model.named_modules())[target_layer]
        h1 = target_layer_module.register_backward_hook(backward_hook)
        h2 = target_layer_module.register_forward_hook(forward_hook)

        # Forward pass
        output = model(input_tensor)

        # If no target class specified, use the class with highest probability
        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass
        model.zero_grad()

        # Get the score for target class (handle different output shapes)
        if len(output.shape) == 4:  # [B, C, H, W] - segmentation output
            class_score = output[0, target_class].mean()
        else:  # [B, C] - classification output
            class_score = output[0, target_class]

        class_score.backward(retain_graph=True)

        # Remove hooks
        h1.remove()
        h2.remove()

        # Check if we got gradients and activations
        if not gradients or not activations:
            print(f"No gradients or activations captured for layer {target_layer}")
            return np.zeros((64, 64)), target_class

        # Generate Grad-CAM
        gradients_tensor = gradients[0][0]  # [C, H, W]
        activations_tensor = activations[0][0]  # [C, H, W]

        # Ensure tensors are on the same device
        device = activations_tensor.device
        gradients_tensor = gradients_tensor.to(device)

        # Calculate weights (global average pooling of gradients)
        weights = torch.mean(gradients_tensor, dim=(1, 2))  # [C]

        # Generate heatmap on the same device
        cam = torch.zeros(activations_tensor.shape[1:], dtype=torch.float32, device=device)
        for i, w in enumerate(weights):
            cam += w * activations_tensor[i, :, :]

        cam = torch.clamp(cam, min=0)  # ReLU
        cam = cam / cam.max() if cam.max() > 0 else cam  # Normalize

        return cam.detach().cpu().numpy(), target_class

    except Exception as e:
        print(f"Error in Grad-CAM for layer {target_layer}: {e}")
        return np.zeros((64, 64)), 0


def visualize_layer_attention(model, input_image, checkpoint_path, save_folder="attention_maps"):
    """
    Visualize what different layers focus on for a given input image
    """
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # Prepare input
    if isinstance(input_image, str):
        # Load image from path
        img = cv2.imread(input_image)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256))
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
    else:
        img_tensor = input_image
        img = (input_image.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    img_tensor = img_tensor.to(device)

    # Updated key layers based on your actual ResUNet++ architecture
    key_layers = [
        'c1.c1.0',      # First conv in stem block
        'c2.c1.2',      # First conv in ResNet block c2
        'c3.c1.2',      # First conv in ResNet block c3
        'c4.c1.2',      # First conv in ResNet block c4
        'b1.c1.0',      # First branch in ASPP
        'd1.r1.c1.2',   # First decoder ResNet block
        'd2.r1.c1.2',   # Second decoder ResNet block
        'output'        # Final output layer
    ]

    # Create figure with appropriate number of subplots
    num_plots = len(key_layers) + 1  # +1 for original image
    cols = 3
    rows = (num_plots + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows))
    if rows == 1:
        axes = axes.reshape(1, -1)
    axes = axes.flatten()

    # Show original image
    axes[0].imshow(img)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    print("Generating attention maps...")

    valid_idx = 1
    for layer_name in key_layers:
        try:
            print(f"Processing layer: {layer_name}")

            # Check if layer exists
            if layer_name not in dict(model.named_modules()):
                print(f"Layer {layer_name} not found, skipping...")
                continue

            # Generate Grad-CAM (try class 1 first, then fallback)
            heatmap, target_class = generate_grad_cam(model, img_tensor.clone(), layer_name, target_class=1)

            if heatmap.max() == 0:  # If no activation, try different class
                heatmap, target_class = generate_grad_cam(model, img_tensor.clone(), layer_name, target_class=None)

            if heatmap.max() == 0:  # Still no activation, create placeholder
                heatmap = np.random.rand(64, 64) * 0.1
                print(f"No significant activation in {layer_name}")

            # Resize heatmap to match input image size
            if heatmap.shape != (256, 256):
                heatmap = cv2.resize(heatmap, (256, 256))

            # Create colored heatmap
            heatmap_normalized = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            heatmap_colored = plt.cm.jet(heatmap_normalized)[:, :, :3]

            # Overlay on original image
            overlay = 0.6 * (img/255.0) + 0.4 * heatmap_colored
            overlay = np.clip(overlay, 0, 1)

            if valid_idx < len(axes):
                axes[valid_idx].imshow(overlay)
                axes[valid_idx].set_title(f'{layer_name}\n(Class: {target_class})', fontsize=10)
                axes[valid_idx].axis('off')

                # Save individual attention map
                plt.imsave(f'{save_folder}/attention_{layer_name.replace(".", "_")}.png',
                          overlay, dpi=300)

                valid_idx += 1

        except Exception as e:
            print(f"Error processing layer {layer_name}: {e}")
            continue

    # Hide unused subplots
    for i in range(valid_idx, len(axes)):
        axes[i].axis('off')

    plt.suptitle('Layer-wise Attention Maps (Grad-CAM)', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'{save_folder}/combined_attention_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Attention maps saved to: {save_folder}")
    return True


def visualize_feature_evolution(model, input_image, checkpoint_path, save_folder="feature_evolution"):
    """
    Show how features evolve through the network layers
    """
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    # Load model
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # Prepare input
    if isinstance(input_image, str):
        img = cv2.imread(input_image)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (256, 256))
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
    else:
        img_tensor = input_image
        img = (input_image.squeeze().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    img_tensor = img_tensor.to(device)

    # Layers to track feature evolution
    evolution_layers = [
        'c1',           # Stem block
        'c2',           # First encoder
        'c3',           # Second encoder
        'c4',           # Third encoder
        'b1',           # Bridge
        'd1',           # First decoder
        'd2',           # Second decoder
        'd3',           # Third decoder
        'output'        # Final output
    ]

    # Hook to capture intermediate outputs
    intermediate_outputs = {}
    hooks = []

    def make_hook(name):
        def hook(module, input, output):
            intermediate_outputs[name] = output.detach()
        return hook

    # Register hooks
    for layer_name in evolution_layers:
        try:
            layer = dict(model.named_modules())[layer_name]
            hook = layer.register_forward_hook(make_hook(layer_name))
            hooks.append(hook)
        except KeyError:
            print(f"Layer {layer_name} not found")

    # Forward pass
    with torch.no_grad():
        output = model(img_tensor)

    # Remove hooks
    for hook in hooks:
        hook.remove()

    # Visualize evolution
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    axes = axes.flatten()

    # Show original image
    axes[0].imshow(img)
    axes[0].set_title('Input Image')
    axes[0].axis('off')

    for idx, layer_name in enumerate(evolution_layers):
        if layer_name in intermediate_outputs and idx + 1 < len(axes):
            feature_map = intermediate_outputs[layer_name][0]  # Remove batch dimension

            if len(feature_map.shape) == 3:  # [C, H, W]
                # Take mean across channels for visualization
                feature_viz = torch.mean(feature_map, dim=0).cpu().numpy()
            else:  # Already 2D
                feature_viz = feature_map.cpu().numpy()

            # Normalize
            feature_viz = (feature_viz - feature_viz.min()) / (feature_viz.max() - feature_viz.min() + 1e-8)

            axes[idx + 1].imshow(feature_viz, cmap='viridis')
            axes[idx + 1].set_title(f'{layer_name}\nShape: {feature_map.shape}')
            axes[idx + 1].axis('off')

    plt.suptitle('Feature Evolution Through Network', fontsize=16)
    plt.tight_layout()
    plt.savefig(f'{save_folder}/feature_evolution.png', dpi=300, bbox_inches='tight')
    plt.show()

    return intermediate_outputs


def comprehensive_network_analysis(checkpoint_path, sample_image_path):
    """
    Comprehensive analysis of what the network sees and understands
    """
    print("="*80)
    print("COMPREHENSIVE NETWORK ANALYSIS")
    print("="*80)

    # Load model
    model = ResUnetPlusPlus(input_ch=3, n_class=3)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()

    # 1. RGB-style kernel visualization
    print("\n1. Analyzing RGB-style filter patterns...")
    key_conv_layers = ['c1.c1.0', 'c2.c1.2', 'c4.c1.2', 'output']

    for layer in key_conv_layers:
        try:
            visualize_conv_kernels_rgb(model, layer)
        except Exception:
            print(f"Could not analyze layer {layer}")

    # 2. Layer attention analysis
    print("\n2. Analyzing layer-wise attention...")
    visualize_layer_attention(model, sample_image_path, checkpoint_path)

    # 3. Feature evolution
    print("\n3. Analyzing feature evolution...")
    visualize_feature_evolution(model, sample_image_path, checkpoint_path)

    print("\nAnalysis complete! Check the generated folders for results.")


# Usage examples:
comprehensive_network_analysis("multiclass_model_998.pt", "./test_images/0048_3.png")
================================================================================
COMPREHENSIVE NETWORK ANALYSIS
================================================================================

1. Analyzing RGB-style filter patterns...
Layer: c1.c1.0, Kernel shape: (16, 3, 3, 3)
[Figure: RGB-style kernels from layer c1.c1.0]
Layer: c2.c1.2, Kernel shape: (32, 16, 3, 3)
[Figure: RGB-style kernels from layer c2.c1.2]
Layer: c4.c1.2, Kernel shape: (128, 64, 3, 3)
[Figure: RGB-style kernels from layer c4.c1.2]
Layer: output, Kernel shape: (3, 16, 1, 1)
[Figure: RGB-style kernels from the output layer]
2. Analyzing layer-wise attention...
Generating attention maps...
Processing layer: c1.c1.0
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py:1864: FutureWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Processing layer: c2.c1.2
Processing layer: c3.c1.2
Processing layer: c4.c1.2
Processing layer: b1.c1.0
Processing layer: d1.r1.c1.2
Processing layer: d2.r1.c1.2
Processing layer: output
[Figure: layer-wise Grad-CAM attention maps overlaid on the input image]
Attention maps saved to: attention_maps

3. Analyzing feature evolution...
[Figure: mean feature maps showing feature evolution through the network layers]
Analysis complete! Check the generated folders for results.
In [ ]:
# # Or individual analyses:
# model = ResUnetPlusPlus(input_ch=3, n_class=3)
# checkpoint = torch.load("multiclass_model_998.pt")
# model.load_state_dict(checkpoint['state_dict'])
# visualize_conv_kernels_rgb(model, "c1.c1.0")  # First layer RGB kernels
In [ ]: