September 13, 2018
by Frank
这篇论文旨在挑战结构搜索,通过将该任务定义成一个可微分的形式,而不是像传统的做法:在离散的不可微分的空间中利用增强学习来实现结构搜索。该方法基于结构表示的连续松弛,允许使用梯度下降等高效的方法进行结构搜索。后续实验表明该算法在探索高性能的用于图像识别的 CNN 结构和语言建模的 RNN 结构上都有良好的表现,并且比现有的 state-of-the-art 非微分结构要快得多。
这篇论文旨在挑战结构搜索,通过将该任务定义成一个可微分的形式,而不是像传统的做法:在离散的不可微分的空间中利用增强学习来实现结构搜索。该方法基于结构表示的连续松弛,允许使用梯度下降等高效的方法进行结构搜索。后续实验表明该算法在探索高性能的用于图像识别的 CNN 结构和语言建模的 RNN 结构上都有良好的表现,并且比现有的 state-of-the-art 非微分结构要快得多。
自动搜索的网络结构在图像分类和目标检测等任务上取得了很有竞争力的表现,然而目前最好的结构搜索算法对于计算能力有非常高的要求。比如利用强化学习获得一个在 CIFAR-10 和 ImageNet 上的 state-of-the-art 结构需要 1800GPU days。现在也有一些加速的方法,比如设置搜索空间/权值/性能预测的特定结构和跨结构的权值共享。但在可扩展性上的基本矛盾依旧存在。这些主流方法(RL,MCTS,SMBO,bayesian 优化)比较低效率的原因在于:将这个结构搜索问题看作一个离散空间上的黑箱优化问题,这就导致了需要大量的结构评估次数。
在这项工作中提出了解决该问题的另外一种方法:darts(Differentiable Architecture Search),将搜索空间进行松弛,使其变得连续,因此这个结构可以通过梯度下降的方法进行优化,使得 darts 可以实现与 state-of-the-art 相竞争的性能,并且只需要少得多的计算量。该方法也超过了另外一个高效的网络搜索方法的性能(ENAS)。darts 比其他很多网络都要更加简单,因为其不包含任何 controller/hypernetworks/performance predictors。目前 darts 在搜索 CNN 和 RNN 上的性能都表现良好。
在连续空间上进行结构搜索的想法并不是最新的想法,但是在这里也有一些主要的区别:前人的工作主要是做 fine -tune,比如 filter shape 等。darts 能够在复杂的有较大搜索空间的图拓扑结构中探索高性能结构,并且能够同时应用于 CNN 和 RNN。实验已经证明,darts 在 CIFAR-10,ImageNet,PTB 上都取得了很好的结果。
本文的主要贡献可以总结为如下几条:
1.介绍了一种可微分的结构搜索算法,可以同时应用在 CNN 和 RNN 上
2.通过在图像分类和语言建模的实验,证明了基于梯度的结构搜索实现了在 CIFAR-10 上的有竞争力的结果,在 PTB 上的性能超过了其他 state-of-the-art 结构,有趣的是目前最好的结构搜索算法利用了不可微的,基于 RL 和 evoluiton 的搜索技巧。
3.实现了很好的结构搜索效率,使用的基于梯度的优化而不是非微分的搜索技术
4.darts 在 cifar-10 和 PTB 上学习的结构可以被迁移到 ImageNet 和 wikitext2 上
这一节首先总体介绍了搜索空间,整个计算流程被描述成一张有向无环图。之后介绍了连接的连续松弛,以及权值和结构的联合优化。最后介绍了一种近似技术来使计算变得可行且高效。
目的是搜索一种计算 cell 来作为最后结构的 building block,可以将其搭建成 CNN 或者循环连接成 RNN。一个 cell 里面是有 N 个节点的有向无环图,其中每个节点相当于一个 feature map,每个有向边相当于一个操作,用于对进行变换,每个节点都是由之前的节点计算而来:
当两个节点之间没有连接时,引入一种空操作,记为操作
令为所有可能操作所构成的集合,其中每个操作表示为,为了让搜索空间变得连续,像 softmax 一样将可行操作的选择在所有操作上进行松弛化
两个节点之间的操作编码为维的向量,进行松弛之后,搜索结构的任务变成了搜索一个集合,该集合也称作对于结构的编码,最后再将混合操作替换为最可能的操作即可,例如可以选择向量中权值最大的对应操作作为选定的操作。
在进行松弛之后,下一个目标就是对于权值和网络结构进行联合优化,与 RL 和 evolution 一样都是使用 valset 的精度作为 reward,不过这里是使用的梯度方法进行优化而不是在离散空间上进行搜索。以分别作为训练集和验证集上的损失,而损失是由共同决定的,这里引出一个双层优化问题:
其实也可以将看作一个超参数,只不过这个超参数的维度要远大于那些标量超参数
直接求解这样的双层优化问题是很困难的,这里给出了一种近似的算法:
其实也就是在初始情况下先固定网络结构,对权值进行优化,然后再在当前权值的基础上对网络结构进行优化,往返进行直到收敛为止,最后再将混合操作确定成某个操作。注意上述公式在更新网络结构时其中的作为一个单步的梯度下降是为了近似,一种相似的方法应用在了 model transfer 的 meta-learning 上,这样的迭代算法在之间定义了一个 Stackelberg game,然而目前并不能保证该算法可以收敛,不过在实际应用中在合适选择的时候是可以收敛的。同时注意到当在权值的优化中使用动量的时候,单步前向算法同样会随之改变,之前的分析仍然是有用的。
在第一步中对的梯度就是普通的神经网络反向传播过程,而在第二步中要对网络结构进行升级,需要计算对的梯度,通过计算可以得到:
其中,是单步前向过程之后的结果,在上面公式的后面一项包括矩阵与向量之间的乘积,这个操作需要大量的计算量,不过可以通过一种近似方法来实现,令为一个极小量:
其中的
可以将计算复杂度从减少为
当的时候,上面公式中的二阶项则不起作用了,整个损失函数对于的梯度则变为,这样的话相当于去掉为了近似下一个而进行的单步传播,这样会提升速度但是会导致更差的性能,在之后将的情况称为一阶近似,而将的情况称为二阶近似
在训练完成了之后需要生成对应的离散结构,首先对于每个中间节点选取 k 个最强的前继节点,在这里对于 CNN,k=2,对于 RNN,k=1。这个 edge 的强度定义为:
然后将每一个混合操作通过取最大的可能值来确定为一个具体操作
在一个简单的 L_val 和 L_train 下进行实验所得到的收敛路径如下图,通过与该问题的解析解的对比,在迭代算法中正确的选择可以收敛到一个更好的结果
在 CIFAR-10 和 PTB 上进行的实验包含两个阶段:结构搜索和结构评估。在第一个阶段搜索可能的 cell 并基于在 val 集上的表现决定最好的 cell 集。在第二个阶段用这些 cell 来组建更大的网络,然后在测试集上来报告其性能。最后利用最好的 cell 来对可迁移性进行了评估。这里的总结主要写 CNN 相关的部分,而不包括 RNN 的内容。
这里只总结卷积相关的部分,操作集包括:3x3 和 5x5 separable conv,3x3 和 5x5 dilated separable conv,3x3 max pooling,3x3 avg pooling,identy,zero。所有的操作是 stride=1,对于卷积操作使用了 ReLU-Conv-BN 的顺序,并且每个 separable conv 都应用了两次。
这里的 conv cell 包括 7 个 node,输出节点定义为内部节点的 depthwise concat,cell k 里面的第一个和第二个 node 与 cell k-1 和 cell k-2 的输出相等,并且在必要的时候使用 1x1 conv。在网络 1/3 和 2/3 部分的 cell 是 reduction cell,在这里面所有与输入节点相连的操作都是 stride=2。于是结构编码分为。
在整个搜索过程中结构会发生变化,在 BN 中使用 batch-specific statistics 而不是 global moving avg,在所有 BN 中都禁止了可学习的仿射参数。
将 CIFAR-10 的训练数据的一半作为 val set,训练了一个 cell=8,50epochs,batch_size=64,initial_channels=16,这些参数设定保证能在一个 gpu 上进行训练,w 的优化使用 momentum SGD,的优化使用 Adam,总共花了一天时间在单个 GPU 上。
为了选取应用于 evaluation 的结构,用不同的随机数种子运行了 4 次 darts,根据验证集的性能选取了最好的 cell。然后再将权值随机初始化后进行训练,再观察其在测试集上的表现。
通过小型网络将 cell 选定之后再训练较大的网络,20 cells,600 epoch,batch_size=96,一些其他的增强措施:cutout,path dropout,auxiliary towers。单个 GPU 上训练了 1.5 天。
darts 在 cifar10 和 imagenet 的实验结果如下图,在 imagenet 上使用了和 cifar10 一样的 cells:
上图中的 NASNet 和 AmoebaNet 是 state-of-the-art,相比而言在效果上差距不大,但是搜索时间少了很多,同样的对于 imagenet 数据集也是一样的效果。
darts 是第一个可微分的结构搜索算法,同时应用于 CNN 和 RNN,通过在连续空间内进行搜索,darts 可以在图像分类和语言建模的任务上面匹配甚至性能超过 state-of-the-art 的非微分结构搜索算法,并且明显更加高效。在未来会利用 darts 探索一种在更大的任务上的直接结构搜索