开源代码:
作者个人主页:
实验室主页:
背景
去年一次组会上,在和导师们讨论未来的 research 方向的时候,偶然聊到一个问题:
视频网站的视频播放会自动根据网络带宽调整画质,如网速好的时候到4K,网速差就720P甚至更低。那同一个神经网络能不能随时根据计算资源的变化调整推理速度?
从 2012 的 AlexNet 到 2023 年火出圈的 ChatGPT, AI/ML 这一社区在十年间少说已经训练了上百万个模型。截至这篇文章写作时,HuggingFace 上可以直接下载的模型就有 14 万个,涵盖各个模态和任务。每个模型各司其职,用自己在训练中学到的知识去处理某一种场景,互不叨扰。
模型虽然越来越多,但是资源浪费也越来越严重。训练一个模型的成本很高,尤其是大模型训练,耗费数个节点和几天的算力才能得到一个好权重,但最后却受限于应用场景只能重新调整结构,然后再重新训练,如网络 backbone 设计中通常会有不同 scale 来满足不同的推理速度要求: ResNet-18/50/101,DeiT-Ti/S/B,Swin-Ti/S/B 等等。
传统方法当然能加速模型推理,如 pruning,distillation,quantization。但问题是这些方法一次大都只能针对一个模型,一个资源场景。我们也可以用 NAS 搜出来若干个子网络来满足不同推理速度需求,即使如此,NAS 中训练一个 Supernet 的成本也是巨大的,典型的如 OFA 和 BigNAS,花费上千 GPU hours 才得到一个好网络,资源消耗巨大。
看着 huggingface 上这么大的 model zoo,我们不禁想,整个社区花了大量时间,金钱和人力资源去训练网络,得到了这么多的 pretrained model,但是能不能有效利用起来?况且这些模型已经训练好了,当需要他们的时候,能不能用少量计算资源就可以满足目标场景?
对这一问题的思考也是随着模型被工业界越推越大引出的。几年前一张 1080 就能跑完的实验,现在 8 张卡都很难 train 得动一个 model,特别是 Transformer 出来之后。最新的 ViT 已经 scale 到 22B,BAAI 的 EVA 也把 ViT 扩展到了 1B 的参数级别。留给小组的空间越来越小,在资源有限(缺卡)的场景下,我们需要寻求新的突破方向。
Stitchable Neural Networks
Industry 和 Academia 所关注的问题可以有些区别。既然大模型不是所有人能做得起的,那我们不如去利用好已有的 pretrained model。现在我们有了一组训练好的 model family,如 DeiT-Tiny/Small/Base。不同模型有不同大小,推理速度,显存占用。那么能不能利用这些已有的 weights 和结构快速得到一批新网络来满足不同的资源场景?
我们在 CVPR 2023 最新的工作 Stitchable Neural Network (SN-Net) 给出了一个非常具有潜力的方案。
SN-Net 的主要思想是:在一组已经训练好的 model family 中插入若干个 stitching layer (即 1x1 conv), 使得 forward 时 activation 可以在模型间的不同位置游走。当模型在不同位置缝合的时候,一个个新网络结构就出来了!!!
此时,我们把原先 model family 中的网络叫做 anchors,缝合出来的新网络叫做 stitches。单个 SN-Net 可以 cover 众多 FLOPs-accuracy 的 trade-off,如在基于 Swin 的实验中,一个 SN-Net 的可以挑战 timm 中 200 个独立的模型,整个实验不过是 50 epochs,八张 V100 上训练不到一天。
下面会介绍详细的做法,以及我们当时方法设计时候的考虑。想直接看效果的朋友可以移步最后的结果展示。
1. 模型这么多,怎么去选择
这里主要考虑了几个地方:
不同模型结构在网络中各层学习到的 representation 会有较大差别,缝合出来的网络不一定保证较好的 performance;
不同数据集学到的东西差别也很大,为了保证性能最好保持在相同 pretrained 的 dataset 下;
不同网络的实现和训练方式有差别,工程上很难权衡超参和 data augmentation 的选择。
而同一个结构通常在一个 repo 里,更容易实现。 因此,我们初步关注在相同 dataset 上训练好的 model family 上, 即结构相似,但是模型 scale 不一样,如 DeiT-Ti/S/B。
不同 family 能不能缝合?也能,我们 paper 里有展示结果,但是工程上会比较麻烦,需要 combine 不同 repo 并且权衡超参。
2. 怎么去做缝合?
model stitching 在原先工作中大都是以研究 representation similarity 的形式呈现的,如:
Lenc, Karel, and Andrea Vedaldi. "Understanding image representations by measuring their equivariance and equivalence." CVPR 2015.
Kornblith, Simon, et al. "Similarity of neural network representations revisited." ICML, 2019.
Csiszárik, Adrián, et al. "Similarity and matching of neural network representations." NeurIPS 2021.
总结过去这些工作:同一个网络,用不同 seed 训练之后可以在某些位置缝合起来,此时性能不会掉的很离谱。后续的研究发现结构不一样的网络甚至也能缝合。
而 stitching 能够 work 在于,假设前一个网络出来的 feature map 属于 activation 空间 A,而另一个网络在此位置的输入 feature map 属于 activation 空间 B,那么 stitching layer 做的事情就是把 feature map 从 A 空间映射到 B 空间,使得此时的 feature map 能模拟下一网络在这个位置的输入。
当网络是已经是 pretrained,那么 stitching 这一过程完全可以 formulate 成一个求解 least squares 的问题。也就是说 stitching layer 这个 weights 的 matrix 是可以直接求出来的 (参考 Csiszárik, Adrián, et al 这篇)。所以此时求解出来的 matrix 可以天然作为 stitching layer 的初始化。
3. 缝合方向的设定
现在我们有一个大模型:性能好但是推理速度慢,还有一个小模型:性能差点但是推理速度快。我们怎么决定谁 stitch 到谁呢?我们主要考虑了两个方面:
参考当前 backbone 设计的惯例,随着网络不断深入,channel dimension 是在不断增大的。Fast-to-Slow 这方向比较符合常见的网络设计;
实验验证 Fast-to-Slow 得到的 curve 要比 Slow-to-Fast 要 smooth 一点,详见论文。
所以目前 SN-Net 在方向上是从小模型缝合到大模型。同时我们提出一个 constraint: nearest stitching,限制 stitching 只在复杂度 (FLOPs) 相邻的两个 anchor 之间。如补充材料中的 Figure 10 所示,以 DeiT-Ti/S/B 为例,我们的方法目前限制在 (a), (b) 两个 case。
这个限制是因为我们发现 anchor 的 gap 比较大的时候,缝合出来的网络并不在一个 optimal 的区间。实验部分也证明直接 stitch DeiT-Ti 和 DeiT-B 效果不如中间加一个 DeiT-S。
4. 怎么配置Stitching Layer
网络设计地千奇百怪,怎么去缝合是个问题。
我们以 DeiT 为例,在相同 depth 的缝合实验上采取了 Paired Stitching 这种策略。这种策略的启发来自于过去一些工作发现:相邻 layer 之间的 representation 是有较高的相似度的。所以我们选择在 DeiT 得相邻 blocks 中 share 同一个 stitching layer,如滑窗一般进行 stitching。
share 的情况下,原先的初始化方法就是简单地对不同 solution 得到的 matrix 做一个 average。选择 share stitching layer 还有其他好处,如减少过多 stitching layer 带来的参数量,同时扩大缝合出来的结构数量,即扩大 stitching space。
另外一种情况是两个模型的 depth 不一样,小模型一般比较浅,block 的数量要比大模型少。比如 Swin-Ti 的第三个 stage 只有 6 个 block,而 Swin-S 在第三个 stage 有 18 个 block。此时我们进行 Unpaired Stitching,每个小模型的 block 都 stitch 到大模型的若干个 block 中。这样两个 case 就都解决了。
5. SN-Net能缝出来多少网络?
这个由多种因素决定。
看选择的 model family,即 anchors 的 depth。显然 anchor 越深,那么能 stitch 的位置就越多,新网络结构也会更多。
相同 depth 下看 stitching 时 sliding window 的设置。
不加 nearest stitching 的时候得到的网络更多 (DeiT 上的实验是十倍的差距,71 vs. 731)。但是此时不 optimal。后续潜力尚待挖掘。
对比 NAS 中级别的s earch space,SN-Net 在基于同一组 model family 得到的网络数量是有限的。但有一点不得不提,纵使 search space 再大,真正需要的时候也只是用 pareto frontier 上的网络结构,而 SN-Net 缝合出来的网络几乎天然落在 pareto frontier 上,同时部署的时候完全可以直接查表,几乎没有什么 search cost。
另外一点是,SN-Net 的潜力在于整个 pretrained model zoo。有多少 model familiy,就有多少潜在的 SN-Net 变种。这是 NAS 的单一 supernet 所不能比拟的。
这意味着我们可以轻易缝合已有的 model family 达到 NAS 耗费大量计算资源搜出来的网络性能,比如简单缝合两个 LeViT 就可以用更低的 FLOPs(977M vs. 1040M)达到媲美 BigNASModel-XL 的性能(80.7% vs. 80.9%),如下图所示。
6. 简单的训练策略
训练 SN-Net 尤为简单。先提前把所有需要训练的 stitches 定义好,训练中每次 iteration 都随机 sample 出来一个 stitch,后面和正常的训练一样进行 loss回传,梯度下降。为了进一步提升 stitches 的性能,我们初步实验同时采用了 RegNetY-160 作为 teacher model 去做 distillation。
结果展示
为了验证 Joint Training 和原有网络从头 train 的差距,我们选择了若干个和 stitches 相同的网络结构,然后在 ImageNet 上训满 300 epochs。从下表可以看到,对比用了大量计算资源训练出来的网络,SN-Net 利用已有的 DeiT family 只用 50 个 epoch 就可以得到比肩甚至更好的性能。同时整个网络只要 118.4M 的参数,而这 71 个 stitches 的总量如果单独训练需要 2630M,耗费 71 × 300 epochs,和 SN-Net 比是 22 倍的差距。
基于 DeiT 和 Swin Transformer,我们验证了缝合 plain ViT 和 hierarchical ViT 的可行性。性能曲线如在 anchors 中进行插值一般。
值得一提的是,图中不同点所表示的子网络,即 stitch,是可以在运行时随时切换的。这意味着网络在 runtime 完全可以依靠查表进行瞬时推理速度调整。这个是诸多网络无法实现的,但颇具现实意义。比如现在很多手机都有省电模式,一旦进行 power saving, 手机掉帧,系统运行速度变慢,而此时 neural network 也可以调整推理速度,做一个 speed-accuracy 的 trade-off。
我们当然也尝试了 stitch cnn,甚至不同的 family,结果非常 promising。
更多实验内容和分析请移步我们的 arxiv 论文。
SN-Net的可扩展空间
SN-Net 生于 large model zoo 的时代。我们初版方法给出了一个最简单的 baseline,相信未来有很大的扩展空间,比如:
1. 当前的训练策略比较简单,每次 iteration sample 出来一个 stitch,但是当 stitches 特别多的时候,可能导致某些 stitch 训练的不够充分,除非增加训练时间。所以训练策略上可以继续改进;
2. anchor 的 performance 会比之前下降一些,虽然不大。直觉上,在 joint training 过程中,anchor 为了保证众多 stitches 的性能在自身 weights 上做了一些 trade-off。目前补充材料里发现 finetune 更多 epoch 可以把这部分损失补回来;
3. 不用 nearest stitching 可以明显扩大 space,但此时大部分网络不在 pareto frontier 上,未来可以结合训练策略进行改进,或者在其他地方发现 advantage;
4. 未来能否有个更好方法和统一的框架去缝合任意网络。到那时,整个 model zoo 就像积木一样,可操作空间更大,玩法更多,这一点 NUS 的 Xingyi Yang 之前有尝试,参考 Deep Model Reassembly.
更多探索就留给 future work 了。代码已经开源,硬件要求十分友好,50 个 epoch(用 8 卡 V100 大约半天时间)就可以复现结果。欢迎有兴趣的同学进行尝试。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。