这是一个将programmable switch当成一个分布式机器学习的parameter server,从而实现line rate的梯度聚集的工作。它是第一个提出将programmable switch用在分布式机器学习的梯度聚集上的工作,引用量也比较多。和parameter server相比,可以避免end-host processing,达到“sub-RTT” latency;和allreduce相比,可以将传输数据量减少大约2倍。和一些梯度压缩的技术相比也不一定慢,同时还有无损梯度聚集的优点。
它有三个创新点:
- 为了适应programmable switch的弱算力,文章提出将参数的更新分成小块,从而可以被programmable switch流水线处理。这个严格来说可能不算创新,而是一种adaptation。逻辑上的设计都是非常简单的。
这个算法虽然很简单,但是其实很巧妙。将需要聚集的分块后,worker和switch的分块是一个有序的直接相联映射,它隐含了多个worker之间的共识:收到switch的返回就说明梯度聚集完成,并且下一个需要聚集的分块的位置也是确定且一致的。
The coordination is implicit because the mapping between model updates, slots, and packets is deterministic.
- 为了保证worker之间的同步,以及检测和恢复丢包的问题,控制面需要能够区分是回传正确结果时的丢包还是worker传数据的时的丢包。文章的解决方案是保存两个switch状态:上一次slot里面的结果的shadow copy,以及当前的slot接收到的worker发来的数据。这个机制只能应对worker落后一个chunk的情况,如果worker持续掉线的话是会阻塞的,而分布式机器学习也不会像一般的分布式系统那样考虑太多的可用性,就简单阻塞住就行。它给的switch端算法逻辑9-17行好像有点问题,大概理解就行。
- 为了适应programmable switch不支持浮点运算,文章提出在switch上做定点数加法,然后在worker端做scaling。worker首先要对这个scaling的幂值达成共识,文章用的方法是在发一个block的时候带上下一个block的梯度的最大幂值,然后在switch返回的时候带上全局最大的幂值给每个worker。第一个block的幂值怎么办文章没说。
implementation部分还有一些P4和RDMA的具体实现的讨论,但我暂时不太懂,这部分比较复杂,可能实际用到的时候才有动力去学。extention部分讨论了这个框架可以怎样扩展到多机架(跨交换机)的场景,引入拥塞控制模块,不同的硬件设备以及多租户的集群环境。