In [ ]:
!pip install torch torchvision
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.6.0+cu101)
Requirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (0.7.0+cu101)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.18.5)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.16.0)
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (7.0.0)
In [ ]:
# based on: https://gist.github.com/AFAgarap/4f8a8d8edf352271fa06d85ba0361f26
In [1]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
In [2]:
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
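
`torch.manual_seed` seeds PyTorch's own generators, but NumPy and Python's `random` module keep separate state. A fuller seeding helper might look like this (a minimal sketch; `seed_everything` is our own name, not a library function):

In [ ]:
import random

def seed_everything(seed):
    # seed Python, NumPy, and all PyTorch generators (CPU and every GPU)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(seed)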
In [3]:
batch_size = 512
epochs = 200
learning_rate = 1e-3
In [4]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /root/torch_datasets/MNIST/raw/train-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /root/torch_datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw/t10k-images-idx3-ubyte.gz

Extracting /root/torch_datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/torch_datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /root/torch_datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/torch_datasets/MNIST/raw
Processing...
Done!
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:469: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

Without Dropout

In [5]:
class AE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed
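
As a quick sanity check (a sketch on random data, not MNIST), the autoencoder should map a batch of flattened 28x28 inputs back to the same shape:

In [ ]:
_model = AE(input_shape=784)
_x = torch.rand(4, 784)  # dummy batch of 4 flattened images in [0, 1]
print(_model(_x).shape)  # expected: torch.Size([4, 784])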
In [6]:
#  use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ",device)

# create a model from the `AE` autoencoder class
# and load it to the specified device, either gpu or cpu
model = AE(input_shape=784).to(device)

# create an optimizer object
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# mean-squared error loss
criterion = nn.MSELoss()
device =  cuda
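
Since `ToTensor()` scales pixels to [0, 1] and the decoder ends in a sigmoid, binary cross-entropy is a common alternative reconstruction loss for MNIST (a sketch; the rest of this notebook sticks with MSE):

In [ ]:
criterion_bce = nn.BCELoss()  # expects predictions and targets in [0, 1]
# inside the training loop this would be used as:
#   train_loss = criterion_bce(outputs, batch_features)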
In [7]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)
In [ ]:
losses = []
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to a [N, 784] matrix
        # and load it to the active device
        batch_features = batch_features.view(-1, 784).to(device)
        
        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()
        
        # compute reconstructions
        outputs = model(batch_features)
        
        # compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)
        
        # compute accumulated gradients
        train_loss.backward()
        
        # perform parameter update based on current gradients
        optimizer.step()
        
        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()
    
    # compute the epoch training loss
    loss = loss / len(train_loader)
    
    # record and display the epoch training loss
    losses.append(loss)
    print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
epoch : 1/200, recon loss = 0.06415338
epoch : 2/200, recon loss = 0.04154352
epoch : 3/200, recon loss = 0.03541024
epoch : 4/200, recon loss = 0.02795628
epoch : 5/200, recon loss = 0.02329513
epoch : 6/200, recon loss = 0.02071499
epoch : 7/200, recon loss = 0.01866129
epoch : 8/200, recon loss = 0.01691499
epoch : 9/200, recon loss = 0.01557805
epoch : 10/200, recon loss = 0.01403124
epoch : 11/200, recon loss = 0.01278244
epoch : 12/200, recon loss = 0.01190003
epoch : 13/200, recon loss = 0.01126922
epoch : 14/200, recon loss = 0.01070535
epoch : 15/200, recon loss = 0.01022582
epoch : 16/200, recon loss = 0.00979880
epoch : 17/200, recon loss = 0.00939511
epoch : 18/200, recon loss = 0.00899579
epoch : 19/200, recon loss = 0.00866037
epoch : 20/200, recon loss = 0.00837372
epoch : 21/200, recon loss = 0.00811422
epoch : 22/200, recon loss = 0.00787735
epoch : 23/200, recon loss = 0.00767605
epoch : 24/200, recon loss = 0.00750453
epoch : 25/200, recon loss = 0.00734724
epoch : 26/200, recon loss = 0.00719759
epoch : 27/200, recon loss = 0.00704476
epoch : 28/200, recon loss = 0.00688683
epoch : 29/200, recon loss = 0.00674700
epoch : 30/200, recon loss = 0.00660915
epoch : 31/200, recon loss = 0.00647049
epoch : 32/200, recon loss = 0.00633768
epoch : 33/200, recon loss = 0.00620463
epoch : 34/200, recon loss = 0.00608786
epoch : 35/200, recon loss = 0.00597996
epoch : 36/200, recon loss = 0.00586265
epoch : 37/200, recon loss = 0.00577069
epoch : 38/200, recon loss = 0.00567200
epoch : 39/200, recon loss = 0.00558520
epoch : 40/200, recon loss = 0.00550607
epoch : 41/200, recon loss = 0.00543390
epoch : 42/200, recon loss = 0.00534967
epoch : 43/200, recon loss = 0.00528255
epoch : 44/200, recon loss = 0.00521349
epoch : 45/200, recon loss = 0.00514333
epoch : 46/200, recon loss = 0.00508021
epoch : 47/200, recon loss = 0.00502300
epoch : 48/200, recon loss = 0.00495950
epoch : 49/200, recon loss = 0.00490377
epoch : 50/200, recon loss = 0.00485274
epoch : 51/200, recon loss = 0.00480395
epoch : 52/200, recon loss = 0.00475865
epoch : 53/200, recon loss = 0.00471667
epoch : 54/200, recon loss = 0.00466510
epoch : 55/200, recon loss = 0.00463050
epoch : 56/200, recon loss = 0.00458789
epoch : 57/200, recon loss = 0.00455083
epoch : 58/200, recon loss = 0.00451195
epoch : 59/200, recon loss = 0.00447663
epoch : 60/200, recon loss = 0.00444807
epoch : 61/200, recon loss = 0.00440377
epoch : 62/200, recon loss = 0.00437592
epoch : 63/200, recon loss = 0.00434485
epoch : 64/200, recon loss = 0.00431782
epoch : 65/200, recon loss = 0.00427906
epoch : 66/200, recon loss = 0.00424610
epoch : 67/200, recon loss = 0.00420826
epoch : 68/200, recon loss = 0.00418040
epoch : 69/200, recon loss = 0.00414816
epoch : 70/200, recon loss = 0.00411809
epoch : 71/200, recon loss = 0.00408783
epoch : 72/200, recon loss = 0.00406279
epoch : 73/200, recon loss = 0.00403382
epoch : 74/200, recon loss = 0.00401184
epoch : 75/200, recon loss = 0.00398444
epoch : 76/200, recon loss = 0.00396081
epoch : 77/200, recon loss = 0.00394266
epoch : 78/200, recon loss = 0.00391589
epoch : 79/200, recon loss = 0.00389317
epoch : 80/200, recon loss = 0.00387960
epoch : 81/200, recon loss = 0.00384918
epoch : 82/200, recon loss = 0.00382716
epoch : 83/200, recon loss = 0.00381153
epoch : 84/200, recon loss = 0.00379020
epoch : 85/200, recon loss = 0.00377023
epoch : 86/200, recon loss = 0.00374067
epoch : 87/200, recon loss = 0.00372741
epoch : 88/200, recon loss = 0.00370483
epoch : 89/200, recon loss = 0.00367854
epoch : 90/200, recon loss = 0.00366501
epoch : 91/200, recon loss = 0.00363971
epoch : 92/200, recon loss = 0.00362428
epoch : 93/200, recon loss = 0.00360167
epoch : 94/200, recon loss = 0.00358546
epoch : 95/200, recon loss = 0.00357241
epoch : 96/200, recon loss = 0.00355187
epoch : 97/200, recon loss = 0.00353387
epoch : 98/200, recon loss = 0.00351539
epoch : 99/200, recon loss = 0.00349951
epoch : 100/200, recon loss = 0.00348720
epoch : 101/200, recon loss = 0.00346995
epoch : 102/200, recon loss = 0.00345725
epoch : 103/200, recon loss = 0.00343904
epoch : 104/200, recon loss = 0.00342354
epoch : 105/200, recon loss = 0.00340937
epoch : 106/200, recon loss = 0.00339752
epoch : 107/200, recon loss = 0.00338149
epoch : 108/200, recon loss = 0.00336982
epoch : 109/200, recon loss = 0.00335678
epoch : 110/200, recon loss = 0.00334347
epoch : 111/200, recon loss = 0.00333428
epoch : 112/200, recon loss = 0.00331881
epoch : 113/200, recon loss = 0.00330533
epoch : 114/200, recon loss = 0.00329794
epoch : 115/200, recon loss = 0.00328625
epoch : 116/200, recon loss = 0.00327241
epoch : 117/200, recon loss = 0.00326193
epoch : 118/200, recon loss = 0.00324943
epoch : 119/200, recon loss = 0.00324483
epoch : 120/200, recon loss = 0.00322642
epoch : 121/200, recon loss = 0.00321995
epoch : 122/200, recon loss = 0.00320030
epoch : 123/200, recon loss = 0.00319765
epoch : 124/200, recon loss = 0.00318341
epoch : 125/200, recon loss = 0.00317950
epoch : 126/200, recon loss = 0.00317259
epoch : 127/200, recon loss = 0.00316137
epoch : 128/200, recon loss = 0.00314789
epoch : 129/200, recon loss = 0.00314072
epoch : 130/200, recon loss = 0.00313256
epoch : 131/200, recon loss = 0.00312197
epoch : 132/200, recon loss = 0.00311291
epoch : 133/200, recon loss = 0.00310338
epoch : 134/200, recon loss = 0.00309822
epoch : 135/200, recon loss = 0.00308383
epoch : 136/200, recon loss = 0.00308044
epoch : 137/200, recon loss = 0.00306875
epoch : 138/200, recon loss = 0.00306406
epoch : 139/200, recon loss = 0.00305265
epoch : 140/200, recon loss = 0.00304308
epoch : 141/200, recon loss = 0.00303854
epoch : 142/200, recon loss = 0.00302832
epoch : 143/200, recon loss = 0.00302330
epoch : 144/200, recon loss = 0.00301338
epoch : 145/200, recon loss = 0.00300301
epoch : 146/200, recon loss = 0.00300307
epoch : 147/200, recon loss = 0.00298918
epoch : 148/200, recon loss = 0.00298198
epoch : 149/200, recon loss = 0.00297517
epoch : 150/200, recon loss = 0.00296499
epoch : 151/200, recon loss = 0.00295828
epoch : 152/200, recon loss = 0.00295425
epoch : 153/200, recon loss = 0.00294349
epoch : 154/200, recon loss = 0.00293354
epoch : 155/200, recon loss = 0.00292972
epoch : 156/200, recon loss = 0.00292370
epoch : 157/200, recon loss = 0.00291145
epoch : 158/200, recon loss = 0.00290740
epoch : 159/200, recon loss = 0.00289831
epoch : 160/200, recon loss = 0.00288795
epoch : 161/200, recon loss = 0.00289152
epoch : 162/200, recon loss = 0.00287760
epoch : 163/200, recon loss = 0.00286812
epoch : 164/200, recon loss = 0.00286277
epoch : 165/200, recon loss = 0.00285185
epoch : 166/200, recon loss = 0.00285136
epoch : 167/200, recon loss = 0.00284419
epoch : 168/200, recon loss = 0.00283532
epoch : 169/200, recon loss = 0.00283528
epoch : 170/200, recon loss = 0.00282344
epoch : 171/200, recon loss = 0.00281540
epoch : 172/200, recon loss = 0.00280948
epoch : 173/200, recon loss = 0.00280846
epoch : 174/200, recon loss = 0.00280017
epoch : 175/200, recon loss = 0.00279204
epoch : 176/200, recon loss = 0.00278541
epoch : 177/200, recon loss = 0.00278195
epoch : 178/200, recon loss = 0.00277612
epoch : 179/200, recon loss = 0.00277704
epoch : 180/200, recon loss = 0.00276329
epoch : 181/200, recon loss = 0.00276216
epoch : 182/200, recon loss = 0.00275160
epoch : 183/200, recon loss = 0.00274933
epoch : 184/200, recon loss = 0.00274087
epoch : 185/200, recon loss = 0.00273846
epoch : 186/200, recon loss = 0.00272918
epoch : 187/200, recon loss = 0.00273322
epoch : 188/200, recon loss = 0.00272122
epoch : 189/200, recon loss = 0.00271904
epoch : 190/200, recon loss = 0.00271407
epoch : 191/200, recon loss = 0.00270791
epoch : 192/200, recon loss = 0.00270369
epoch : 193/200, recon loss = 0.00270009
epoch : 194/200, recon loss = 0.00269427
epoch : 195/200, recon loss = 0.00268801
epoch : 196/200, recon loss = 0.00268824
epoch : 197/200, recon loss = 0.00268362
epoch : 198/200, recon loss = 0.00268294
epoch : 199/200, recon loss = 0.00267602
epoch : 200/200, recon loss = 0.00267046
In [44]:
# re-create the model before loading weights when resuming from a fresh session
model = AE(input_shape=784).to(device)
model.load_state_dict(torch.load("Mnist_without_dropout.pt", map_location=device))
model.eval()
Out[44]:
AE(
  (encoder_hidden_layer): Linear(in_features=784, out_features=128, bias=True)
  (encoder_output_layer): Linear(in_features=128, out_features=128, bias=True)
  (decoder_hidden_layer): Linear(in_features=128, out_features=128, bias=True)
  (decoder_output_layer): Linear(in_features=128, out_features=784, bias=True)
)
In [47]:
test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=False
)

test_examples = None

with torch.no_grad():
    for batch_features in test_loader:
        batch_features = batch_features[0]  # keep the images, drop the labels
        test_examples = batch_features.view(-1, 784).to(device)
        reconstruction = model(test_examples)
        break

Compare the results

In [48]:
with torch.no_grad():
    number = 10
    plt.figure(figsize=(20, 6))
    for index in range(number):
        # display original
        ax = plt.subplot(3, number, index + 1)
        plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(3, number, index + 1 + number)
        plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display difference
        ax = plt.subplot(3, number, index + 2*number + 1)
        original = test_examples[index].cpu().numpy().reshape(28, 28)
        generated = reconstruction[index].cpu().numpy().reshape(28, 28)
        diff = original - generated
        plt.imshow(diff, cmap="viridis")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
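
Visual inspection is hard to rank, so a quantitative check helps: mean reconstruction MSE over the whole test set (a sketch reusing `test_loader`, `model`, and `criterion` from above):

In [ ]:
model.eval()
total, n_batches = 0.0, 0
with torch.no_grad():
    for batch_features, _ in test_loader:
        batch_features = batch_features.view(-1, 784).to(device)
        total += criterion(model(batch_features), batch_features).item()
        n_batches += 1
print("mean test recon loss = {:.8f}".format(total / n_batches))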

Loss graph

In [ ]:
plt.plot(losses, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Losses');
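
Since the loss falls by more than an order of magnitude, a log-scale y-axis shows the tail of training more clearly (a minor variant of the plot above):

In [ ]:
plt.semilogy(losses, '-')
plt.xlabel('epoch')
plt.ylabel('loss (log scale)')
plt.title('Losses');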

Check all the parameters

In [ ]:
for item in model.parameters():
  print(item.shape)
torch.Size([128, 784])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([784, 128])
torch.Size([784])
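
The cells below select parameters by position (the counter `k`), so it helps to know which index corresponds to which layer; `named_parameters()` makes the mapping explicit:

In [ ]:
for k, (name, p) in enumerate(model.named_parameters()):
  print(k, name, tuple(p.shape))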
In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(30,30))
plt.axis('off')

for item in model.parameters():
  # first weight matrix is [128, 784]: each row has 784 entries -> one 28x28 image
  for item_1 in item:
    img_weights = torch.reshape(item_1, (28, 28))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k == 0:
    k = 1
    continue
  # encoder hidden-layer bias: 128 values -> one 8x16 image
  img_weights = torch.reshape(item, (8, 16))
  plt.axis('off')
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()
  break

Weights learned in the bottleneck

In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(32,16))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 1:
    k += 1
    continue
  for item_1 in item:
    # each row of the [128, 128] bottleneck matrix -> one 8x16 image
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
In [ ]:
item.shape
Out[ ]:
torch.Size([128, 128])

Biases of the bottleneck encoder

In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 2:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (8, 16))
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()
  break
torch.Size([128])

Weights of the bottleneck decoder

In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(32,16))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 3:
    k += 1
    continue
  print(item.shape)
  
  for item_1 in item:
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  
  break
torch.Size([128, 128])

Biases of the bottleneck decoder

In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 4:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (8, 16))
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()  
  break
torch.Size([128])
In [ ]:
for item in model.parameters():
  print(item.shape)
torch.Size([128, 784])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([784, 128])
torch.Size([784])

Decoder weights

In [ ]:
total_img = 784
num_pr = 28 # number per row

counter = 1
plt.figure(figsize=(64*2,32*2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 5:
    k += 1
    continue
  print(item.shape)
  for item_1 in item:
    # each row of the [784, 128] decoder matrix has 128 entries -> one 8x16 image
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
torch.Size([784, 128])

Decoder biases

In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')
k = 0
for item in model.parameters():
  if k <= 6:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (28, 28))
  plt.axis('off')
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")

  break
torch.Size([784])
In [ ]:
torch.save(model.state_dict(), "Mnist_without_dropout.pt")

With Dropout

In [24]:
class AEDropouts(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.drop_hl1 = nn.Dropout(0.5)
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.drop_hl2 = nn.Dropout(0.5)
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(self.drop_hl1(activation))
        
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(self.drop_hl2(activation))
        
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed
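
`nn.Dropout` implements inverted dropout: during training the surviving activations are scaled by 1/(1-p), so the expected activation is unchanged and no rescaling is needed at eval time. A quick check (a sketch):

In [ ]:
drop = nn.Dropout(p=0.5)
x = torch.ones(100000)
drop.train()
print(drop(x).mean())  # ~1.0: survivors are scaled by 1/(1-p) = 2
drop.eval()
print(drop(x).mean())  # exactly 1.0: dropout is the identity in eval mode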
In [25]:
#  use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ",device)

# create a model from the `AEDropouts` autoencoder class
# and load it to the specified device, either gpu or cpu
model = AEDropouts(input_shape=784).to(device)

# create an optimizer object
# Adam optimizer with learning rate 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# mean-squared error loss
criterion = nn.MSELoss()
device =  cuda
In [26]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)
In [ ]:
# useful discussion of nn.Dropout vs F.dropout:
# https://stackoverflow.com/questions/53419474/using-dropout-in-pytorch-nn-dropout-vs-f-dropout
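
The short version of that thread (a minimal sketch): `nn.Dropout` is a registered submodule, so `model.eval()` switches it off automatically, whereas the functional form `F.dropout` must be told explicitly via its `training` argument:

In [ ]:
import torch.nn.functional as F

x = torch.ones(8)
drop = nn.Dropout(p=0.5)
drop.eval()
print(drop(x))                              # identity: eval() disables the module
print(F.dropout(x, p=0.5, training=False))  # identity only because training=False was passed
print(F.dropout(x, p=0.5, training=True))   # still drops, regardless of any eval() call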
In [ ]:
losses = []
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to a [N, 784] matrix
        # and load it to the active device
        batch_features = batch_features.view(-1, 784).to(device)
        
        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()
        
        # compute reconstructions
        outputs = model(batch_features)
        
        # compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)
        
        # compute accumulated gradients
        train_loss.backward()
        
        # perform parameter update based on current gradients
        optimizer.step()
        
        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()
    
    # compute the epoch training loss
    loss = loss / len(train_loader)
    
    # record and display the epoch training loss
    losses.append(loss)
    print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
epoch : 1/200, recon loss = 0.06787991
epoch : 2/200, recon loss = 0.04734913
epoch : 3/200, recon loss = 0.04276034
epoch : 4/200, recon loss = 0.04104461
epoch : 5/200, recon loss = 0.03985566
epoch : 6/200, recon loss = 0.03901467
epoch : 7/200, recon loss = 0.03830852
epoch : 8/200, recon loss = 0.03779165
epoch : 9/200, recon loss = 0.03729005
epoch : 10/200, recon loss = 0.03686484
epoch : 11/200, recon loss = 0.03651264
epoch : 12/200, recon loss = 0.03613999
epoch : 13/200, recon loss = 0.03591668
epoch : 14/200, recon loss = 0.03566672
epoch : 15/200, recon loss = 0.03557408
epoch : 16/200, recon loss = 0.03536771
epoch : 17/200, recon loss = 0.03519088
epoch : 18/200, recon loss = 0.03508433
epoch : 19/200, recon loss = 0.03498988
epoch : 20/200, recon loss = 0.03485433
epoch : 21/200, recon loss = 0.03480767
epoch : 22/200, recon loss = 0.03472511
epoch : 23/200, recon loss = 0.03464253
epoch : 24/200, recon loss = 0.03453601
epoch : 25/200, recon loss = 0.03447530
epoch : 26/200, recon loss = 0.03440757
epoch : 27/200, recon loss = 0.03433920
epoch : 28/200, recon loss = 0.03436335
epoch : 29/200, recon loss = 0.03425738
epoch : 30/200, recon loss = 0.03419059
epoch : 31/200, recon loss = 0.03419725
epoch : 32/200, recon loss = 0.03408757
epoch : 33/200, recon loss = 0.03407245
epoch : 34/200, recon loss = 0.03397934
epoch : 35/200, recon loss = 0.03397765
epoch : 36/200, recon loss = 0.03396586
epoch : 37/200, recon loss = 0.03392269
epoch : 38/200, recon loss = 0.03393081
epoch : 39/200, recon loss = 0.03386310
epoch : 40/200, recon loss = 0.03379047
epoch : 41/200, recon loss = 0.03379472
epoch : 42/200, recon loss = 0.03381384
epoch : 43/200, recon loss = 0.03371069
epoch : 44/200, recon loss = 0.03373424
epoch : 45/200, recon loss = 0.03369320
epoch : 46/200, recon loss = 0.03366347
epoch : 47/200, recon loss = 0.03368948
epoch : 48/200, recon loss = 0.03362246
epoch : 49/200, recon loss = 0.03360149
epoch : 50/200, recon loss = 0.03355088
epoch : 51/200, recon loss = 0.03357455
epoch : 52/200, recon loss = 0.03354184
epoch : 53/200, recon loss = 0.03350033
epoch : 54/200, recon loss = 0.03348105
epoch : 55/200, recon loss = 0.03348197
epoch : 56/200, recon loss = 0.03345409
epoch : 57/200, recon loss = 0.03345971
epoch : 58/200, recon loss = 0.03343131
epoch : 59/200, recon loss = 0.03341563
epoch : 60/200, recon loss = 0.03341407
epoch : 61/200, recon loss = 0.03343615
epoch : 62/200, recon loss = 0.03333437
epoch : 63/200, recon loss = 0.03339827
epoch : 64/200, recon loss = 0.03337417
epoch : 65/200, recon loss = 0.03331758
epoch : 66/200, recon loss = 0.03331548
epoch : 67/200, recon loss = 0.03327017
epoch : 68/200, recon loss = 0.03328061
epoch : 69/200, recon loss = 0.03321650
epoch : 70/200, recon loss = 0.03320600
epoch : 71/200, recon loss = 0.03321669
epoch : 72/200, recon loss = 0.03321582
epoch : 73/200, recon loss = 0.03315486
epoch : 74/200, recon loss = 0.03315523
epoch : 75/200, recon loss = 0.03309210
epoch : 76/200, recon loss = 0.03310189
epoch : 77/200, recon loss = 0.03312915
epoch : 78/200, recon loss = 0.03312182
epoch : 79/200, recon loss = 0.03307657
epoch : 80/200, recon loss = 0.03309655
epoch : 81/200, recon loss = 0.03308805
epoch : 82/200, recon loss = 0.03306603
epoch : 83/200, recon loss = 0.03306362
epoch : 84/200, recon loss = 0.03305515
epoch : 85/200, recon loss = 0.03303161
epoch : 86/200, recon loss = 0.03299602
epoch : 87/200, recon loss = 0.03308810
epoch : 88/200, recon loss = 0.03295467
epoch : 89/200, recon loss = 0.03297906
epoch : 90/200, recon loss = 0.03301324
epoch : 91/200, recon loss = 0.03296737
epoch : 92/200, recon loss = 0.03293894
epoch : 93/200, recon loss = 0.03298372
epoch : 94/200, recon loss = 0.03295693
epoch : 95/200, recon loss = 0.03294547
epoch : 96/200, recon loss = 0.03295382
epoch : 97/200, recon loss = 0.03292201
epoch : 98/200, recon loss = 0.03283836
epoch : 99/200, recon loss = 0.03286240
epoch : 100/200, recon loss = 0.03283152
epoch : 101/200, recon loss = 0.03282903
epoch : 102/200, recon loss = 0.03282569
epoch : 103/200, recon loss = 0.03283098
epoch : 104/200, recon loss = 0.03280598
epoch : 105/200, recon loss = 0.03277236
epoch : 106/200, recon loss = 0.03284918
epoch : 107/200, recon loss = 0.03279235
epoch : 108/200, recon loss = 0.03276818
epoch : 109/200, recon loss = 0.03278851
epoch : 110/200, recon loss = 0.03282598
epoch : 111/200, recon loss = 0.03275116
epoch : 112/200, recon loss = 0.03272804
epoch : 113/200, recon loss = 0.03276458
epoch : 114/200, recon loss = 0.03268726
epoch : 115/200, recon loss = 0.03269993
epoch : 116/200, recon loss = 0.03278802
epoch : 117/200, recon loss = 0.03273899
epoch : 118/200, recon loss = 0.03265529
epoch : 119/200, recon loss = 0.03270122
epoch : 120/200, recon loss = 0.03269600
epoch : 121/200, recon loss = 0.03269528
epoch : 122/200, recon loss = 0.03266993
epoch : 123/200, recon loss = 0.03269863
epoch : 124/200, recon loss = 0.03269689
epoch : 125/200, recon loss = 0.03262891
epoch : 126/200, recon loss = 0.03262957
epoch : 127/200, recon loss = 0.03261413
epoch : 128/200, recon loss = 0.03265882
epoch : 129/200, recon loss = 0.03263096
epoch : 130/200, recon loss = 0.03263403
epoch : 131/200, recon loss = 0.03259492
epoch : 132/200, recon loss = 0.03262198
epoch : 133/200, recon loss = 0.03261854
epoch : 134/200, recon loss = 0.03259024
epoch : 135/200, recon loss = 0.03257424
epoch : 136/200, recon loss = 0.03255695
epoch : 137/200, recon loss = 0.03258559
epoch : 138/200, recon loss = 0.03250474
epoch : 139/200, recon loss = 0.03253296
epoch : 140/200, recon loss = 0.03253258
epoch : 141/200, recon loss = 0.03254005
epoch : 142/200, recon loss = 0.03251991
epoch : 143/200, recon loss = 0.03253502
epoch : 144/200, recon loss = 0.03253127
epoch : 145/200, recon loss = 0.03247309
epoch : 146/200, recon loss = 0.03254236
epoch : 147/200, recon loss = 0.03252508
epoch : 148/200, recon loss = 0.03250609
epoch : 149/200, recon loss = 0.03256159
epoch : 150/200, recon loss = 0.03250008
epoch : 151/200, recon loss = 0.03247455
epoch : 152/200, recon loss = 0.03245800
epoch : 153/200, recon loss = 0.03250400
epoch : 154/200, recon loss = 0.03244282
epoch : 155/200, recon loss = 0.03244042
epoch : 156/200, recon loss = 0.03246173
epoch : 157/200, recon loss = 0.03247826
epoch : 158/200, recon loss = 0.03244982
epoch : 159/200, recon loss = 0.03244960
epoch : 160/200, recon loss = 0.03243473
epoch : 161/200, recon loss = 0.03241576
epoch : 162/200, recon loss = 0.03242218
epoch : 163/200, recon loss = 0.03244574
epoch : 164/200, recon loss = 0.03239550
epoch : 165/200, recon loss = 0.03240758
epoch : 166/200, recon loss = 0.03235424
epoch : 167/200, recon loss = 0.03234117
epoch : 168/200, recon loss = 0.03238857
epoch : 169/200, recon loss = 0.03236992
epoch : 170/200, recon loss = 0.03226839
epoch : 171/200, recon loss = 0.03229935
epoch : 172/200, recon loss = 0.03229281
epoch : 173/200, recon loss = 0.03231021
epoch : 174/200, recon loss = 0.03229899
epoch : 175/200, recon loss = 0.03229971
epoch : 176/200, recon loss = 0.03229929
epoch : 177/200, recon loss = 0.03229438
epoch : 178/200, recon loss = 0.03227556
epoch : 179/200, recon loss = 0.03223644
epoch : 180/200, recon loss = 0.03225725
epoch : 181/200, recon loss = 0.03224995
epoch : 182/200, recon loss = 0.03226307
epoch : 183/200, recon loss = 0.03225984
epoch : 184/200, recon loss = 0.03223259
epoch : 185/200, recon loss = 0.03223612
epoch : 186/200, recon loss = 0.03222683
epoch : 187/200, recon loss = 0.03219394
epoch : 188/200, recon loss = 0.03220955
epoch : 189/200, recon loss = 0.03222084
epoch : 190/200, recon loss = 0.03221722
epoch : 191/200, recon loss = 0.03223931
epoch : 192/200, recon loss = 0.03221014
epoch : 193/200, recon loss = 0.03218750
epoch : 194/200, recon loss = 0.03213995
epoch : 195/200, recon loss = 0.03222483
epoch : 196/200, recon loss = 0.03212500
epoch : 197/200, recon loss = 0.03218815
epoch : 198/200, recon loss = 0.03215718
epoch : 199/200, recon loss = 0.03213737
epoch : 200/200, recon loss = 0.03221352
In [49]:
model.load_state_dict(torch.load("Mnist_with_dropout.pt"))
model.eval()
Out[49]:
AEDropouts(
  (encoder_hidden_layer): Linear(in_features=784, out_features=128, bias=True)
  (drop_hl1): Dropout(p=0.5, inplace=False)
  (encoder_output_layer): Linear(in_features=128, out_features=128, bias=True)
  (decoder_hidden_layer): Linear(in_features=128, out_features=128, bias=True)
  (drop_hl2): Dropout(p=0.5, inplace=False)
  (decoder_output_layer): Linear(in_features=128, out_features=784, bias=True)
)
In [50]:
test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=False
)

test_examples = None

with torch.no_grad():
    for batch_features in test_loader:
        batch_features = batch_features[0]  # keep the images, drop the labels
        test_examples = batch_features.view(-1, 784).to(device)
        reconstruction = model(test_examples)
        break
In [51]:
with torch.no_grad():
    number = 10
    plt.figure(figsize=(20, 6))
    for index in range(number):
        # display original
        ax = plt.subplot(3, number, index + 1)
        plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display reconstruction
        ax = plt.subplot(3, number, index + 1 + number)
        plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # display difference
        ax = plt.subplot(3, number, index + 2*number + 1)
        original = test_examples[index].cpu().numpy().reshape(28, 28)
        generated = reconstruction[index].cpu().numpy().reshape(28, 28)
        diff = original - generated
        plt.imshow(diff, cmap="viridis")
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
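
The same quantitative check as in the no-dropout section, now for the dropout model (`model.eval()` matters here, since it disables the dropout layers):

In [ ]:
model.eval()
total, n_batches = 0.0, 0
with torch.no_grad():
    for batch_features, _ in test_loader:
        batch_features = batch_features.view(-1, 784).to(device)
        total += criterion(model(batch_features), batch_features).item()
        n_batches += 1
print("mean test recon loss = {:.8f}".format(total / n_batches))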
In [ ]:
plt.plot(losses, '-')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Losses');
In [ ]:
for item in model.parameters():
  print(item.shape)
torch.Size([128, 784])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([784, 128])
torch.Size([784])
In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(30,30))
plt.axis('off')

for item in model.parameters():
  
  for item_1 in item:
    img_weights = torch.reshape(item_1, (28, 28))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k == 0:
    k = 1
    continue
  img_weights = torch.reshape(item, (8, 16))
  plt.axis('off')
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()
  break
In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(32,16))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 1:
    k += 1
    continue
  for item_1 in item:
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 2:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (8, 16))
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()
  break
torch.Size([128])
In [ ]:
total_img = 128
num_pr = 16 # number per row

counter = 1
plt.figure(figsize=(32,16))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 3:
    k += 1
    continue
  print(item.shape)
  
  for item_1 in item:
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  
  break
torch.Size([128, 128])
In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 4:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (8, 16))
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
  plt.show()  
  break
torch.Size([128])
In [ ]:
total_img = 784
num_pr = 28 # number per row

counter = 1
plt.figure(figsize=(64*2,32*2))
plt.axis('off')

k = 0
for item in model.parameters():
  if k <= 5:
    k += 1
    continue
  print(item.shape)
  for item_1 in item:
    # each row of the [784, 128] decoder matrix has 128 entries -> one 8x16 image
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
  plt.show()
  break
torch.Size([784, 128])
In [ ]:
counter = 1
plt.figure(figsize=(2,2))
plt.axis('off')
k = 0
for item in model.parameters():
  if k <= 6:
    k += 1
    continue
  print(item.shape)
  img_weights = torch.reshape(item, (28, 28))
  plt.axis('off')
  plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")

  break
torch.Size([784])
In [ ]:
torch.save(model.state_dict(), "Mnist_with_dropout.pt")