手把手 | OpenAI开发可拓展元学习算法Reptile,能快速学习(附代码)
大数据文摘作品
编译:Zoe Zuo、丁慧、Aileen
本文来自OpenAI博客,介绍一种新的元学习算法Retile。
在OpenAI, 我们开发了一种简易的元学习算法,称为Reptile。它通过对任务进行重复采样,利用随机梯度下降法,并将初始参数更新为在该任务上学习的最终参数。
其性能可以和MAML(model-agnostic meta-learning,由伯克利AI研究所研发的一种应用广泛的元学习算法)相媲美,操作简便且计算效率更高。
MAML元学习算法:
http://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/
元学习是学习如何学习的过程。此算法接受大量各种的任务进行训练,每项任务都是一个学习问题,然后产生一个快速的学习器,并且能够通过少量的样本进行泛化。
一个深入研究的元学习问题是小样本分类(few-shot classification),其中每项任务都是一个分类问题,学习器在每个类别下只能看到1到5个输入-输出样本(input-output examples),然后就要给新输入的样本进行分类。
下面是应用了Reptile算法的单样本分类(1-shot classification)的互动演示,大家可以尝试一下。
尝试单击“Edit All”按钮,绘制三个不同的形状或符号,然后在右侧的输入区中绘制其中一个,并查看Reptile如何对它进行分类。前三张图是标记样本,每图定义一个类别。最后一张图代表未知样本,Reptile要输出此图属于每个类别的概率。
Reptile的工作原理
像MAML一样,Reptile试图初始化神经网络的参数,以便通过新任务产生的少量数据来对网络进行微调。
但是,当MAML借助梯度下降算法的计算图来展开和区分时,Reptile只是以标准方法在每个任务中执行随机梯度下降(stochastic gradient descent, SGD)算法,并不展开计算图或者计算二阶导数。这使得Reptile比MAML需要更少的计算和内存。示例代码如下:
初始化Φ,初始参数向量
对于迭代1,2,3……执行
随机抽样任务T
在任务T上执行k>1步的SGD,输入参数Φ,输出参数w
更新:Φ←Φ+ϵ(w−Φ)
结束
返回Φ
最后一步中,我们可以将Φ−W作为梯度,并将其插入像这篇论文里(https://arxiv.org/abs/1412.6980)Adam这样更为先进的优化器中作为替代方案。
首先令人惊讶的是,这种方法完全有效。如果k=1,这个算法就相当于 “联合训练”(joint training)——对多项任务的混合体执行SGD。虽然在某些情况下,联合训练可以学习到有用的初始化,但当零样本学习(zero-shot learning)不可能实现时(比如,当输出标签是随机排列时),联合训练就几乎无法学习得到结果。
Reptile要求k>1,也就是说,参数更新要依赖于损失函数的高阶导数实现,此时算法的表现和k=1(联合训练)时是完全不同的。
为了分析Reptile的工作原理,我们使用泰勒级数(Taylor series)来逼近参数更新。Reptile的更新将同一任务中不同小批量的梯度内积(inner product)最大化,从而提高了的泛化能力。
这一发现可能超出了元学习领域的指导意义,比如可以用来解释SGD的泛化性质。进一步分析表明,Reptile和MAML的更新过程很相近,都包括两个不同权重的项。
泰勒级数:
https://en.wikipedia.org/wiki/Taylor_series
在我们的实验中,展示了Reptile和MAML在Omniglot和Mini-ImageNet基准测试中对少量样本分类时产生相似的性能,由于更新具有较小的方差,因此Reptile也可以更快的收敛到解决方案。
Omniglot:
https://github.com/brendenlake/omniglot
Mini-ImageNet:
https://arxiv.org/abs/1606.04080
我们对Reptile的分析表明,通过不同的SGD梯度组合,可以获得大量不同的算法。在下图中,假设针对每一任务中不同小批量执行k步SGD,得出的梯度分别为g1,g2,…,gk。
下图显示了在 Omniglot 上由梯度之和作为元梯度而绘制出的学习曲线。g2对应一阶MAML,也就是原先MAML论文中提出的算法。由于方差缩减,纳入更多梯度明显会加速学习过程。需要注意的是,仅仅使用g1(对应k=1)并不会给这个任务带来改进,因为零样本学习的性能无法得到改善。
X坐标:外循环迭代次数
Y坐标:Omniglot对比5种方式的
5次分类的准确度
算法实现
我们在GitHub上提供了Reptile的算法实现,它使用TensorFlow来完成相关计算,并包含用于在Omniglot和Mini-ImageNet上小样本分类实验的代码。我们还发布了一个较小的JavaScript实现,对TensorFlow预先训练好的模型进行了微调。文章开头的互动演示也是借助JavaScript完成的。
GitHub:
https://github.com/openai/supervised-reptile
较小的JavaScript实现:
https://github.com/openai/supervised-reptile/tree/master/web
最后,展示一个小样本回归(few-shot regression)的简单示例,用以预测10(x,y)对的随机正弦波。该示例基于PyTorch实现,代码如下:
import numpy as np
import torch
from torch import nn, autograd as ag
import matplotlib.pyplot as plt
from copy import deepcopy
seed = 0
plot = True
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it
rng = np.random.RandomState(seed)
torch.manual_seed(seed)
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
"Generate classification problem"
phase = rng.uniform(low=0, high=2*np.pi)
ampl = rng.uniform(0.1, 5)
f_randomsine = lambda x : np.sin(x + phase) * ampl
return f_randomsine
# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results
model = nn.Sequential(
nn.Linear(1, 64),
nn.Tanh(),
nn.Linear(64, 64),
nn.Tanh(),
nn.Linear(64, 1),
)
def totorch(x):
return ag.Variable(torch.Tensor(x))
def train_on_batch(x, y):
x = totorch(x)
y = totorch(y)
model.zero_grad()
ypred = model(x)
loss = (ypred - y).pow(2).mean()
loss.backward()
for param in model.parameters():
param.data -= innerstepsize * param.grad.data
def predict(x):
x = totorch(x)
return model(x).data.numpy()
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
# Reptile training loop
for iteration in range(niterations):
weights_before = deepcopy(model.state_dict())
# Generate task
f = gen_task()
y_all = f(x_all)
# Do SGD on this task
inds = rng.permutation(len(x_all))
for _ in range(innerepochs):
for start in range(0, len(x_all), ntrain):
mbinds = inds[start:start+ntrain]
train_on_batch(x_all[mbinds], y_all[mbinds])
# Interpolate between current weights and trained weights from this task
# I.e. (weights_before - weights_after) is the meta-gradient
weights_after = model.state_dict()
outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule
model.load_state_dict({name :
weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize
for name in weights_before})
# Periodically plot the results on a particular task and minibatch
if plot and iteration==0 or (iteration+1) % 1000 == 0:
plt.cla()
f = f_plot
weights_before = deepcopy(model.state_dict()) # save snapshot before evaluation
plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))
for inneriter in range(32):
train_on_batch(xtrain_plot, f(xtrain_plot))
if (inneriter+1) % 8 == 0:
frac = (inneriter+1) / 32
plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))
plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
lossval = np.square(predict(x_all) - f(x_all)).mean()
plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
plt.ylim(-4,4)
plt.legend(loc="lower right")
plt.pause(0.01)
model.load_state_dict(weights_before) # restore from snapshot
print(f"-----------------------------")
print(f"iteration {iteration+1}")
print(f"loss on plotted curve {lossval:.3f}") # would be better to average loss over a set of examples, but this is optimized for brevity
论文链接:
https://arxiv.org/abs/1803.02999
代码链接:
https://github.com/openai/supervised-reptile
原文链接:
https://blog.openai.com/reptile/
【今日机器学习概念】
Have a Great Definition
志愿者介绍
回复“志愿者”加入我们
关注公众号:拾黑(shiheibook)了解更多
[广告]赞助链接:
四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/
随时掌握互联网精彩
- 1 习近平拉美之行的三个“一” 7968509
- 2 微信或史诗级“瘦身” 内存有救了 7952553
- 3 男子求助如何打开亡父遗留14年手机 7887185
- 4 中国主张成为G20峰会的一抹亮色 7714061
- 5 中国对日本等国试行免签 7646629
- 6 7万余件儿童羽绒服里没有真羽绒 7572840
- 7 女生半裸遭男保洁刷卡闯入 酒店回应 7402534
- 8 70多辆小米SU7同一天撞墙撞柱 7315618
- 9 操纵股价 2人被证监会罚没近3.35亿 7295885
- 10 千年古镇“因网而变、因数而兴” 7136633