!pip install torch torchvision
# https://gist.github.com/AFAgarap/4f8a8d8edf352271fa06d85ba0361f26
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
# Training hyperparameters.
batch_size = 512
epochs = 200
learning_rate = 1e-3

# Reproducibility: seed PyTorch's RNG and force deterministic cuDNN kernels.
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# MNIST training split; ToTensor turns each image into a float tensor.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets",
    train=True,
    transform=transform,
    download=True,
)
# NOTE(review): the training cell below rebuilds this loader with
# batch_size=128, so `batch_size` above is effectively unused -- confirm intent.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
class AE(nn.Module):
    """Fully connected autoencoder: input -> 128 -> 128 (code) -> 128 -> input.

    Hidden layers use ReLU; the code layer and the output layer use sigmoid.

    Keyword Args:
        input_shape: number of input features (784 for flattened 28x28 MNIST
            in this notebook).
    """

    def __init__(self, **kwargs):
        super().__init__()
        n_features = kwargs["input_shape"]
        width = 128
        # Layers are created in the original order, so parameter registration
        # order (and therefore seeded initialisation / state_dict keys) match.
        self.encoder_hidden_layer = nn.Linear(in_features=n_features, out_features=width)
        self.encoder_output_layer = nn.Linear(in_features=width, out_features=width)
        self.decoder_hidden_layer = nn.Linear(in_features=width, out_features=width)
        self.decoder_output_layer = nn.Linear(in_features=width, out_features=n_features)

    def forward(self, features):
        """Encode `features` to a sigmoid code, then decode it back."""
        hidden = torch.relu(self.encoder_hidden_layer(features))
        code = torch.sigmoid(self.encoder_output_layer(hidden))
        hidden = torch.relu(self.decoder_hidden_layer(code))
        return torch.sigmoid(self.decoder_output_layer(hidden))
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ",device)

# Create the autoencoder and move it to `device` *before* building the
# optimizer, so Adam holds references to the on-device parameters.
model = AE(input_shape=784).to(device)
# Adam optimizer with learning rate `learning_rate` (1e-3).
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Mean-squared-error reconstruction loss.
criterion = nn.MSELoss()

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)
# NOTE(review): batch_size=128 here, not the `batch_size` (512) hyperparameter
# defined at the top of the file -- confirm which one is intended.
# Pinned host memory only helps CUDA transfers, so enable it only on GPU.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

losses = []
for epoch in range(epochs):
    model.train()
    loss = 0.0
    for batch_features, _ in train_loader:
        # Flatten the mini-batch to [N, 784] and move it to the model's device.
        batch_features = batch_features.view(-1, 784).to(device, non_blocking=True)
        # PyTorch accumulates gradients across backward passes; reset them.
        optimizer.zero_grad()
        outputs = model(batch_features)
        # The autoencoder's target is its own input.
        train_loss = criterion(outputs, batch_features)
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()
    # Mean mini-batch loss for this epoch.
    loss = loss / len(train_loader)
    losses.append(loss)
    print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))

# The plotting cells below call `.numpy()` on model outputs and parameters,
# which only works on CPU tensors, so move the trained model back.
model.to("cpu")
import os

# Checkpoint written by `torch.save` at the end of this section.
checkpoint_path = "Mnist_without_dropout.pt"
# Restore saved weights only when no training ran in this session. The
# checkpoint is written *after* this cell, so loading it unconditionally either
# raises FileNotFoundError on a fresh run or silently replaces the freshly
# trained weights with stale ones on a re-run.
trained_this_session = bool(globals().get("losses"))
if not trained_this_session and os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=False
)

# Reconstruct the first test batch (10 images flattened to 784 pixels). Inputs
# follow the model's device; results come back to the CPU for plotting.
model_device = next(model.parameters()).device
with torch.no_grad():
    batch_features, _ = next(iter(test_loader))
    test_examples = batch_features.view(-1, 784)
    reconstruction = model(test_examples.to(model_device)).cpu()
# Visual check of the first `number` test images:
#   row 1 - originals, row 2 - reconstructions,
#   row 3 - signed difference (original - reconstruction).
number = 10
plt.figure(figsize=(20, 6))
plt.gray()  # grayscale default colormap for rows 1 and 2
with torch.no_grad():
    for index in range(number):
        original = test_examples[index].numpy().reshape(28, 28)
        generated = reconstruction[index].numpy().reshape(28, 28)
        diff = original - generated
        panels = (
            (index + 1, original, None),
            (index + 1 + number, generated, None),
            (index + 1 + 2 * number, diff, "viridis"),
        )
        for position, image, colormap in panels:
            ax = plt.subplot(3, number, position)
            plt.imshow(image, cmap=colormap)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
plt.show()
# Per-epoch mean training reconstruction loss.
loss_axes = plt.gca()
loss_axes.plot(losses, '-')
loss_axes.set_xlabel('epoch')
loss_axes.set_ylabel('loss')
loss_axes.set_title('Losses');
# Shape of every learnable tensor, in registration order.
for parameter in model.parameters():
    print(parameter.shape)
# Grid of the first registered parameter (encoder_hidden_layer.weight): each
# of its rows is reshaped to 28x28 and drawn on a 16x16 grid of subplots.
total_img = 128
num_pr = 16  # number per row
plt.figure(figsize=(30, 30))
plt.axis('off')
first_layer_weight = next(iter(model.parameters()))
for slot, row in enumerate(first_layer_weight, start=1):
    plt.subplot(num_pr, num_pr, slot)
    plt.axis('off')
    plt.imshow(torch.reshape(row, (28, 28)).detach().numpy(), cmap="gray")
plt.show()
# Draw the second registered parameter (encoder_hidden_layer.bias) as one
# 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
encoder_hidden_bias = model.encoder_hidden_layer.bias
bias_image = torch.reshape(encoder_hidden_bias, (8, 16)).detach().numpy()
plt.imshow(bias_image, cmap="gray")
plt.show()
# Grid of encoder_output_layer.weight (third registered parameter): each of
# its rows has 128 = 8 * 16 entries and is drawn as an 8x16 image.
total_img = 128
num_pr = 16  # number per row
counter = 1
plt.figure(figsize=(32, 16))
plt.axis('off')
item = model.encoder_output_layer.weight
for item_1 in item:
    # A 128-entry row always reshapes to 8x16; the former bare
    # `try/except: pass` could only hide genuine errors.
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
plt.show()
# Was a bare notebook expression (`item.shape`) that is a no-op in a script.
print(item.shape)
# Draw encoder_output_layer.bias (fourth registered parameter, 128 values) as
# one 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
item = model.encoder_output_layer.bias
print(item.shape)
# Reshape `item` itself: this previously reshaped `item_1`, a stale row left
# over from the preceding weight-grid cell, and so plotted the wrong tensor.
img_weights = torch.reshape(item, (8, 16))
plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
plt.show()
# Grid of decoder_hidden_layer.weight (fifth registered parameter): each of
# its rows has 128 = 8 * 16 entries and is drawn as an 8x16 image.
total_img = 128
num_pr = 16  # number per row
counter = 1
plt.figure(figsize=(32, 16))
plt.axis('off')
item = model.decoder_hidden_layer.weight
print(item.shape)
for item_1 in item:
    # A 128-entry row always reshapes to 8x16; the former bare
    # `try/except: pass` could only hide genuine errors.
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
plt.show()
# Draw decoder_hidden_layer.bias (sixth registered parameter, 128 values) as
# one 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
item = model.decoder_hidden_layer.bias
print(item.shape)
# Reshape `item` itself: this previously reshaped `item_1`, a stale row left
# over from the preceding weight-grid cell, and so plotted the wrong tensor.
img_weights = torch.reshape(item, (8, 16))
plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
plt.show()
# Re-list the shapes of all learnable tensors, in registration order.
for _name, tensor in model.named_parameters():
    print(tensor.shape)
# Grid of decoder_output_layer.weight (seventh registered parameter): each row
# is reshaped to 8x16 and drawn on a 28x28 grid of subplots.
total_img = 784
num_pr = 28  # number per row
plt.figure(figsize=(64 * 2, 32 * 2))
plt.axis('off')
decoder_out_weight = model.decoder_output_layer.weight
print(decoder_out_weight.shape)
for slot, row in enumerate(decoder_out_weight, start=1):
    plt.subplot(num_pr, num_pr, slot)
    plt.axis('off')
    plt.imshow(torch.reshape(row, (8, 16)).detach().numpy(), cmap="gray")
plt.show()
# Draw decoder_output_layer.bias (last registered parameter) as one 28x28
# image.
# NOTE(review): unlike the other cells there is no plt.show() here; a notebook
# displays the figure automatically.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
output_bias = model.decoder_output_layer.bias
print(output_bias.shape)
plt.imshow(torch.reshape(output_bias, (28, 28)).detach().numpy(), cmap="gray")
# Persist only the trained weights (state_dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "Mnist_without_dropout.pt")
class AEDropouts(nn.Module):
    """Fully connected autoencoder with dropout after each hidden layer.

    Layout: input -> 128 -> 128 (sigmoid code) -> 128 -> input (sigmoid), with
    dropout on the encoder hidden layer (`drop_hl1`) and on the decoder hidden
    layer (`drop_hl2`). nn.Dropout is only active in training mode.

    Keyword Args:
        input_shape: number of input features (784 for flattened 28x28 MNIST
            in this notebook).
        dropout: drop probability for both dropout layers (default 0.5).
    """

    def __init__(self, **kwargs):
        super().__init__()
        drop_p = kwargs.get("dropout", 0.5)
        # Creation order is unchanged, so seeded initialisation and state_dict
        # keys stay the same (Dropout modules hold no parameters).
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )
        self.drop_hl1 = nn.Dropout(drop_p)
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )
        self.drop_hl2 = nn.Dropout(drop_p)
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )

    def forward(self, features):
        """Encode `features` to a sigmoid code, then decode it back."""
        # Dropout before ReLU is equivalent to after it: dropout only zeroes
        # entries or rescales them by a positive factor.
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(self.drop_hl1(activation))
        code = torch.sigmoid(self.encoder_output_layer(activation))
        activation = self.decoder_hidden_layer(code)
        # The decoder uses its own dropout module; this previously reused
        # `drop_hl1`, leaving `drop_hl2` defined but never applied.
        activation = torch.relu(self.drop_hl2(activation))
        reconstructed = torch.sigmoid(self.decoder_output_layer(activation))
        return reconstructed
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ",device)

# Create the dropout autoencoder (`AEDropouts`) and move it to `device`
# *before* building the optimizer, so Adam holds the on-device parameters.
model = AEDropouts(input_shape=784).to(device)
# Adam optimizer with learning rate `learning_rate` (1e-3).
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Mean-squared-error reconstruction loss.
criterion = nn.MSELoss()

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)
# NOTE(review): batch_size=128 here, not the `batch_size` (512) hyperparameter
# defined at the top of the file -- confirm which one is intended.
# Pinned host memory only helps CUDA transfers, so enable it only on GPU.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

# nn.Dropout vs F.dropout background:
# https://stackoverflow.com/questions/53419474/using-dropout-in-pytorch-nn-dropout-vs-f-dropout
losses = []
for epoch in range(epochs):
    # Dropout is only active in training mode.
    model.train()
    loss = 0.0
    for batch_features, _ in train_loader:
        # Flatten the mini-batch to [N, 784] and move it to the model's device.
        batch_features = batch_features.view(-1, 784).to(device, non_blocking=True)
        # PyTorch accumulates gradients across backward passes; reset them.
        optimizer.zero_grad()
        outputs = model(batch_features)
        # The autoencoder's target is its own input.
        train_loss = criterion(outputs, batch_features)
        train_loss.backward()
        optimizer.step()
        loss += train_loss.item()
    # Mean mini-batch loss for this epoch.
    loss = loss / len(train_loader)
    losses.append(loss)
    print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))

# The plotting cells below call `.numpy()` on model outputs and parameters,
# which only works on CPU tensors, so move the trained model back.
model.to("cpu")
import os

# Checkpoint written by `torch.save` at the end of this section.
checkpoint_path = "Mnist_with_dropout.pt"
# Restore saved weights only when no training ran in this session. The
# checkpoint is written *after* this cell, so loading it unconditionally either
# raises FileNotFoundError on a fresh run or silently replaces the freshly
# trained weights with stale ones on a re-run.
trained_this_session = bool(globals().get("losses"))
if not trained_this_session and os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
# Evaluation mode disables dropout for the reconstructions below.
model.eval()

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=False
)

# Reconstruct the first test batch (10 images flattened to 784 pixels). Inputs
# follow the model's device; results come back to the CPU for plotting.
model_device = next(model.parameters()).device
with torch.no_grad():
    batch_features, _ = next(iter(test_loader))
    test_examples = batch_features.view(-1, 784)
    reconstruction = model(test_examples.to(model_device)).cpu()
# Visual check of the first `number` test images for the dropout model:
#   row 1 - originals, row 2 - reconstructions,
#   row 3 - signed difference (original - reconstruction).
number = 10
plt.figure(figsize=(20, 6))
plt.gray()  # grayscale default colormap for rows 1 and 2
with torch.no_grad():
    for index in range(number):
        original = test_examples[index].numpy().reshape(28, 28)
        generated = reconstruction[index].numpy().reshape(28, 28)
        diff = original - generated
        for position, image, colormap in (
            (index + 1, original, None),
            (index + 1 + number, generated, None),
            (index + 1 + 2 * number, diff, "viridis"),
        ):
            ax = plt.subplot(3, number, position)
            plt.imshow(image, cmap=colormap)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
plt.show()
# Per-epoch mean training reconstruction loss of the dropout model.
curve_axes = plt.gca()
curve_axes.plot(losses, '-')
curve_axes.set_xlabel('epoch')
curve_axes.set_ylabel('loss')
curve_axes.set_title('Losses');
# Shape of every learnable tensor, in registration order.
for parameter in model.parameters():
    print(parameter.shape)
# Grid of the first registered parameter (encoder_hidden_layer.weight): each
# row is reshaped to 28x28 and drawn on a 16x16 grid of subplots.
total_img = 128
num_pr = 16  # number per row
plt.figure(figsize=(30, 30))
plt.axis('off')
first_layer_weight = next(iter(model.parameters()))
for slot, row in enumerate(first_layer_weight, start=1):
    plt.subplot(num_pr, num_pr, slot)
    plt.axis('off')
    plt.imshow(torch.reshape(row, (28, 28)).detach().numpy(), cmap="gray")
plt.show()
# Draw the second registered parameter (encoder_hidden_layer.bias) as one
# 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
encoder_hidden_bias = model.encoder_hidden_layer.bias
bias_image = torch.reshape(encoder_hidden_bias, (8, 16)).detach().numpy()
plt.imshow(bias_image, cmap="gray")
plt.show()
# Grid of encoder_output_layer.weight (third registered parameter): each of
# its rows has 128 = 8 * 16 entries and is drawn as an 8x16 image.
total_img = 128
num_pr = 16  # number per row
counter = 1
plt.figure(figsize=(32, 16))
plt.axis('off')
item = model.encoder_output_layer.weight
for item_1 in item:
    # A 128-entry row always reshapes to 8x16; the former bare
    # `try/except: pass` could only hide genuine errors.
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
plt.show()
# Draw encoder_output_layer.bias (fourth registered parameter, 128 values) as
# one 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
item = model.encoder_output_layer.bias
print(item.shape)
# Reshape `item` itself: this previously reshaped `item_1`, a stale row left
# over from the preceding weight-grid cell, and so plotted the wrong tensor.
img_weights = torch.reshape(item, (8, 16))
plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
plt.show()
# Grid of decoder_hidden_layer.weight (fifth registered parameter): each of
# its rows has 128 = 8 * 16 entries and is drawn as an 8x16 image.
total_img = 128
num_pr = 16  # number per row
counter = 1
plt.figure(figsize=(32, 16))
plt.axis('off')
item = model.decoder_hidden_layer.weight
print(item.shape)
for item_1 in item:
    # A 128-entry row always reshapes to 8x16; the former bare
    # `try/except: pass` could only hide genuine errors.
    img_weights = torch.reshape(item_1, (8, 16))
    plt.subplot(num_pr, num_pr, counter)
    plt.axis('off')
    plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
    counter += 1
plt.show()
# Draw decoder_hidden_layer.bias (sixth registered parameter, 128 values) as
# one 8x16 image.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
item = model.decoder_hidden_layer.bias
print(item.shape)
# Reshape `item` itself: this previously reshaped `item_1`, a stale row left
# over from the preceding weight-grid cell, and so plotted the wrong tensor.
img_weights = torch.reshape(item, (8, 16))
plt.imshow(img_weights.detach().cpu().numpy(), cmap="gray")
plt.show()
# Grid of decoder_output_layer.weight (seventh registered parameter): each row
# is reshaped to 8x16 and drawn on a 28x28 grid of subplots.
total_img = 784
num_pr = 28  # number per row
plt.figure(figsize=(64 * 2, 32 * 2))
plt.axis('off')
decoder_out_weight = model.decoder_output_layer.weight
print(decoder_out_weight.shape)
for slot, row in enumerate(decoder_out_weight, start=1):
    plt.subplot(num_pr, num_pr, slot)
    plt.axis('off')
    plt.imshow(torch.reshape(row, (8, 16)).detach().numpy(), cmap="gray")
plt.show()
# Draw decoder_output_layer.bias (last registered parameter) as one 28x28
# image.
# NOTE(review): unlike the other cells there is no plt.show() here; a notebook
# displays the figure automatically.
counter = 1
plt.figure(figsize=(2, 2))
plt.axis('off')
output_bias = model.decoder_output_layer.bias
print(output_bias.shape)
plt.imshow(torch.reshape(output_bias, (28, 28)).detach().numpy(), cmap="gray")
# Persist only the trained weights (state_dict), not the module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "Mnist_with_dropout.pt")