import math
import torch
import numpy as np
from torch.nn import functional as F
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

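# Model-Agnostic Meta-Learning (MAML) on 1-D sine regression: meta-train a
# small MLP so that a few gradient steps adapt it to a freshly sampled sinusoid.
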
def net(x, params):
    # A 1-40-40-1 MLP written functionally, so it can be evaluated with
    # arbitrary parameter lists ("fast weights") during the inner loop.
    x = F.linear(x, params[0], params[1])
    x = F.relu(x)

    x = F.linear(x, params[2], params[3])
    x = F.relu(x)

    x = F.linear(x, params[4], params[5])
    return x

# Weights get a fan-in uniform init (range 1/sqrt(fan_in)); biases start at zero.
# The first layer has fan_in = 1, so its range is (-1, 1).
params = [
    torch.Tensor(40, 1).uniform_(-1., 1.).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(40, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(40).zero_().requires_grad_(),

    torch.Tensor(1, 40).uniform_(-1./math.sqrt(40), 1./math.sqrt(40)).requires_grad_(),
    torch.Tensor(1).zero_().requires_grad_(),
]
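
# Quick sanity check of the functional net: a batch of three scalar inputs
# should map to three scalar outputs.
assert net(torch.zeros(3, 1), params).shape == (3, 1)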

opt = torch.optim.Adam(params, lr=1e-3)  # outer-loop (meta) optimizer
alpha = 0.1                              # inner-loop (adaptation) learning rate

K = 20          # examples per task in each inner batch
losses = []     # outer-loss history, logged every 100 iterations
num_tasks = 4   # tasks per meta-batch
def sample_tasks(outer_batch_size, inner_batch_size):
    # Each task is a sinusoid with a randomly drawn amplitude and phase.
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def get_batch():
        xs, ys = [], []
        for A, phase in zip(As, phases):
            x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
            y = A * np.sin(x + phase)
            xs.append(x)
            ys.append(y)
        # Stack into one ndarray first; building a tensor from a list of
        # numpy arrays is slow and raises a warning in recent PyTorch.
        return (torch.tensor(np.array(xs), dtype=torch.float),
                torch.tensor(np.array(ys), dtype=torch.float))
    # Two independent batches per task: one to adapt on, one to evaluate on.
    x1, y1 = get_batch()
    x2, y2 = get_batch()
    return x1, y1, x2, y2
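
# Each tensor returned by sample_tasks has shape (outer_batch_size, inner_batch_size, 1);
# (x1, y1) is the support batch used for adaptation, (x2, y2) the query batch
# the adapted parameters are scored on.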

for it in range(20000):
    loss2 = 0.0
    opt.zero_grad()
    def get_loss_for_task(x1, y1, x2, y2):
        # Inner step: one SGD update on the support batch.
        f = net(x1, params)
        loss = F.mse_loss(f, y1)

        # create_graph=True because computing grads here is part of the forward pass.
        # We want to differentiate through the SGD update step and get higher-order
        # derivatives in the backward pass.
        grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

        # Score the adapted parameters on the held-out query batch.
        v_f = net(x2, new_params)
        return F.mse_loss(v_f, y2)
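
    # In MAML terms: new_params are the fast weights theta' = theta - alpha * grad L_support(theta),
    # and the value returned is L_query(theta'). Backpropagating it reaches theta through
    # second-order derivatives, which is what create_graph=True enables.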

    # Meta-objective: average query loss over the sampled tasks.
    task = sample_tasks(num_tasks, K)
    inner_losses = [get_loss_for_task(task[0][i], task[1][i], task[2][i], task[3][i])
                    for i in range(num_tasks)]
    loss2 = sum(inner_losses) / len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2.item()))
        # Log a plain float; keeping the graph-attached tensor would hold on to
        # memory and break the np.convolve call at the end of the script.
        losses.append(loss2.item())

# Test-time adaptation: sample a fresh task and four examples from it.
t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A*torch.sin(t_x + t_b)

opt.zero_grad()

# Take a few plain gradient steps from the meta-learned initialization
# (note: L1 loss here, unlike the MSE used during meta-training).
t_params = params
for k in range(5):
    t_f = net(t_x, t_params)
    t_loss = F.l1_loss(t_f, t_y)

    # No meta-update happens at test time, so there is no need to build a graph
    # of the update itself; create_graph=False keeps adaptation cheap.
    grads = torch.autograd.grad(t_loss, t_params, create_graph=False)
    t_params = [(t_params[i] - alpha*grads[i]) for i in range(len(params))]
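
# t_params now holds the task-adapted weights; params, the meta-learned
# initialization, is untouched and could be adapted to other tasks the same way.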

# Plot the ground-truth sinusoid, the adapted network, and the adaptation examples.
test_x = torch.arange(-2*math.pi, 2*math.pi, step=0.01).unsqueeze(1)
test_y = t_A*torch.sin(test_x + t_b)

test_f = net(test_x, t_params)

plt.plot(test_x.data.numpy(), test_y.data.numpy(), label='sin(x)')
plt.plot(test_x.data.numpy(), test_f.data.numpy(), label='net(x)')
plt.plot(t_x.data.numpy(), t_y.data.numpy(), 'o', label='Examples')
plt.legend()
plt.savefig('maml-sine.png')

# Smooth the logged outer losses with a 20-point moving average ([.05]*20 sums
# to 1); mode='valid' drops the partial-overlap artifacts at the edges.
plt.figure()
plt.plot(np.convolve(losses, [.05]*20, mode='valid'))
plt.savefig('losses.png')