Lecture 18: Failures in data privacy

Learning objectives

By the end of this lecture, students should be able to:

  • Understand why encrypted data might not be as private as we thought

  • Understand how anonymized data can be reverse-engineered

  • Understand how models can be inverted to extract input data

Slides

Note

Download a PDF version here

Example of model inversion in a neural network using MNIST data

# !pip install torch torchvision torchaudio

First, we define a neural network in two parts to simulate a Split Neural Network (SplitNN), a training paradigm in which the first part of the network is hosted on the data holder’s device and the second part is hosted on another device.

from torch import nn, optim

class SplitNN(nn.Module):
    def __init__(self):
        super(SplitNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Linear(784, 500),
            nn.ReLU(),
        )
        self.second_part = nn.Sequential(
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 10),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.second_part(self.first_part(x))

target_model = SplitNN()

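Let’s assume the target model has been trained on the MNIST dataset and that we can access the 500-dimensional output vector of the model’s first_part. The attacker is a small network that tries to map this vector back to the 784 pixels of the original image.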
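The code above only instantiates `SplitNN()` with random weights, so the "trained" target is an assumption. A minimal sketch of how such a target could be fitted is shown here. `train_target` is a hypothetical helper, not part of the original notebook, and it assumes a loader that yields `(image, label)` batches, such as the `mnist_train_loader` created later in this lecture. Because the model already ends in `nn.Softmax`, the sketch applies `nn.NLLLoss` to the log of the probabilities instead of using `nn.CrossEntropyLoss`, which expects raw logits.

```python
import torch
from torch import nn, optim


def train_target(model, loader, epochs=1, lr=1e-3):
    """Hypothetical helper: fit a classifier whose forward pass returns softmax probabilities."""
    optimiser = optim.Adam(model.parameters(), lr=lr)
    nll = nn.NLLLoss()
    epoch_losses = []
    model.train()
    for epoch in range(epochs):
        total, batches = 0.0, 0
        for data, targets in loader:
            # Flatten each 28x28 image into a 784-dimensional vector
            data = data.view(data.size(0), -1)
            optimiser.zero_grad()
            probs = model(data)
            # NLLLoss expects log-probabilities; the small constant avoids log(0)
            loss = nll(torch.log(probs + 1e-8), targets)
            loss.backward()
            optimiser.step()
            total += loss.item()
            batches += 1
        # Mean training loss over the batches of this epoch
        epoch_losses.append(total / max(batches, 1))
    return epoch_losses
```

Under these assumptions, a call such as `train_target(target_model, mnist_train_loader)` would run one pass over the training data and return the mean loss for that epoch.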

class Attacker(nn.Module):
    def __init__(self):
        super(Attacker, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(500, 1000),
            nn.ReLU(),
            nn.Linear(1000, 784),
        )

    def forward(self, x):
        return self.layers(x)
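To make the attack surface concrete, the following sketch traces the tensor shapes involved. It uses stand-in `nn.Sequential` modules with the same layer sizes as `first_part` and `Attacker`, and a random batch in place of real images; both are assumptions made only so the snippet runs on its own.

```python
import torch
from torch import nn

# Stand-ins with the same layer sizes as first_part and Attacker in this lecture
first_part = nn.Sequential(nn.Linear(784, 500), nn.ReLU())
attacker = nn.Sequential(nn.Linear(500, 1000), nn.ReLU(), nn.Linear(1000, 784))

# Four fake 28x28 images (random values, not real digits)
batch = torch.rand(4, 1, 28, 28)
flat = batch.view(batch.size(0), -1)  # shape (4, 784)

with torch.no_grad():
    # What leaves the data holder's device and is visible to the attacker
    intermediate = first_part(flat)  # shape (4, 500)
    # The attacker's guess at the original pixels
    reconstruction = attacker(intermediate)  # shape (4, 784)

print(intermediate.shape, reconstruction.shape)
```

The attacker never sees `flat` at attack time; it only needs the 500-dimensional vector, which is why exposing intermediate activations can leak the input.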

Create the data loaders

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the transformations shared by the EMNIST and MNIST datasets
# (pixel values are rescaled from [0, 1] to [-1, 1])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the EMNIST dataset
emnist_train = datasets.EMNIST(root='data', split='balanced', train=True, download=True, transform=transform)
emnist_train_loader = DataLoader(emnist_train, batch_size=64, shuffle=True)
# Load the MNIST dataset
mnist_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)

mnist_test = datasets.MNIST(root='data', train=False, download=True, transform=transform)
mnist_test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

Perhaps we don’t know exactly what data the target model was trained on, but we do know that it consists of handwritten characters.

We can therefore train our attacker on a different dataset of the same kind: the balanced split of EMNIST, which contains handwritten letters and digits.

# Initialize the attacker model and optimizer
attacker = Attacker()
optimiser = optim.Adam(attacker.parameters(), lr=1e-4)

# Training loop
for data, targets in emnist_train_loader:
    # Flatten the input data
    data = data.view(data.size(0), -1)
    
    # Reset gradients
    optimiser.zero_grad()

    # First, get outputs from the target model; detach them so that
    # gradients flow only through the attacker, not the target
    target_outputs = target_model.first_part(data).detach()

    # Next, recreate the data with the attacker
    attack_outputs = attacker(target_outputs)

    # We want attack outputs to resemble the original data
    loss = ((data - attack_outputs) ** 2).mean()

    # Update the attack model
    loss.backward()
    optimiser.step()
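The loop above makes a single pass over the data and does not record the loss. One possible way to package it, sketched here under assumptions, is a function that keeps the target frozen and reports the mean reconstruction loss per epoch. `train_attacker` and its parameters are hypothetical names, not part of the original notebook.

```python
import torch
from torch import optim


def train_attacker(first_part, attacker, loader, epochs=1, lr=1e-4):
    """Hypothetical helper: train `attacker` to invert the frozen `first_part`."""
    optimiser = optim.Adam(attacker.parameters(), lr=lr)
    first_part.eval()
    history = []
    for epoch in range(epochs):
        total, batches = 0.0, 0
        for data, _ in loader:
            # Flatten each image to a 784-dimensional vector
            data = data.view(data.size(0), -1)
            # The target is only queried, never updated
            with torch.no_grad():
                intermediate = first_part(data)
            optimiser.zero_grad()
            # Mean squared error between the input and its reconstruction
            loss = ((data - attacker(intermediate)) ** 2).mean()
            loss.backward()
            optimiser.step()
            total += loss.item()
            batches += 1
        # Mean loss over the batches of this epoch
        history.append(total / max(batches, 1))
    return history
```

With the objects defined in this lecture, this could be called as `train_attacker(target_model.first_part, attacker, emnist_train_loader, epochs=3)`; the returned list gives a rough view of whether more epochs still help.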

Let’s plot the original data against the recreated data.

import matplotlib.pyplot as plt
# Testing loop
for data, targets in mnist_test_loader:
    # Flatten the input data
    data = data.view(data.size(0), -1)

    # Get outputs from the target model
    target_outputs = target_model.first_part(data)

    # Recreate the data with the attacker
    recreated_data = attacker(target_outputs)

    # Plot the first 2 original and recreated images for comparison
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    for i in range(2):
        # Original image
        axes[i, 0].imshow(data[i].view(28, 28).cpu().detach().numpy(), cmap='gray')
        axes[i, 0].set_title(f"Original {i+1}")
        axes[i, 0].axis('off')

        # Recreated image
        axes[i, 1].imshow(recreated_data[i].view(28, 28).cpu().detach().numpy(), cmap='gray')
        axes[i, 1].set_title(f"Recreated {i+1}")
        axes[i, 1].axis('off')

    plt.show()
    break  # Only plot for the first batch
[Figure: a 2×2 grid showing two original MNIST test images ("Original 1", "Original 2") next to the attacker's reconstructions ("Recreated 1", "Recreated 2").]
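A visual comparison of two images is only anecdotal. As a complement, the sketch below computes the mean squared reconstruction error over an entire loader. `reconstruction_mse` is a hypothetical helper, not part of the original notebook, and it assumes the loader yields at least one `(image, label)` batch.

```python
import torch


def reconstruction_mse(first_part, attacker, loader):
    """Hypothetical helper: per-pixel mean squared error of the attacker's reconstructions."""
    total_sq_err, total_values = 0.0, 0
    with torch.no_grad():
        for data, _ in loader:
            # Flatten each image to a 784-dimensional vector
            data = data.view(data.size(0), -1)
            recreated = attacker(first_part(data))
            # Accumulate the squared error and the number of pixels seen
            total_sq_err += ((data - recreated) ** 2).sum().item()
            total_values += data.numel()
    return total_sq_err / total_values
```

For example, `reconstruction_mse(target_model.first_part, attacker, mnist_test_loader)` would summarise how closely the attacker recovers the MNIST test images. Since the pixels are normalised to [-1, 1], the value should be read on that scale.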