关于pytorch中使用detach并不能阻止参数更新这档子事儿

2022 年 2 月 3 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨惠会绘@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/344916574
编辑丨极市平台

极市导读

 

文章分三个部分解释了pytorch中detach和参数更新的问题:1)detach后参数依旧更新的原因 2)step函数的操作 3)如何达到之前提到的效果。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

最近做实验,虽然没做出什么有价值的成果,但是对pytorch中detach和参数更新有了进一步的认识。众所周知,detach能够阻断梯度回传。首先来看一下detach的基本用法, 下面这个模型有两层线性变换组成,中间使用relu激活:


   
   
     
class TestDetach(nn.Module):    def __init__(self, InDim, HiddenDim, OutDim):        super().__init__()        self.layer1 = nn.Linear(InDim, HiddenDim, False)        self.layer2 = nn.Linear(HiddenDim, OutDim, False)
def forward(self, x, DetachLayer1): x = torch.relu(self.layer1(x)) x = x.detach() x = self.layer2(x) return x

第一层的输出后detach,那么第一层的参数将永远不会更新。然而在一些情况下单纯的detach可能并不能达到固定参数的作用,这也就是这篇文章要谈的:在哪些情况下单纯的detach不能阻断参数的更新,然后应该如何去做。

首先来看一个具体的场景:还是上面的模型,它有两个参数layer1的权重矩阵,记为weight1和layer2的权重矩阵,记为weight2,但是我并不想让他们同步更新,而是weight2每更新N次,weight1才更新一次(先不讨论这样做有没有实际的应用意义)。一开始想的很简单,通过detach来阻断梯度的计算,从而阻止参数更新。模型的代码如下:


   
   
     
class TestDetach(nn.Module):    def __init__(self, InDim, HiddenDim, OutDim):        super().__init__()        self.layer1 = nn.Linear(InDim, HiddenDim, False)        self.layer2 = nn.Linear(HiddenDim, OutDim, False)
def forward(self, x, DetachLayer1): # 多传了一个bool参数,以指示是否detach x = torch.relu(self.layer1(x)) if DetachLayer1: x = x.detach() x = self.layer2(x) return x

前向传播的时候为模型传递一个bool参数DetachLayer1,当其为真的时候对第一层的输出进行detach,以阻止梯度的反向传播。之后是训练的代码:


   
   
     
def train():    # 随机生成数据    N, F = 5000, 196    x = torch.Tensor(np.random.randn(N, F))    y = torch.LongTensor(np.random.randint(0, 3, (N,)))    # 生成损失函数,模型,优化器    LossFunc = nn.CrossEntropyLoss()    model = TestDetach(F, 64, 3)    optimizer = torch.optim.Adam(model.parameters(), lr=1, weight_decay=0.5)    # 训练    for epoch in range(100):        # 将DetachLayer1交替置为True和False        DetachLayer1 = False if epoch % 2 == 0 else True        Yhat = model(x, DetachLayer1)        Loss = LossFunc(Yhat, y)        print(f"Epoch {epoch}, DetachLayer1: {DetachLayer1}, Loss: {Loss:.4f}")        Loss.backward()        optimizer.step()        optimizer.zero_grad()

猛地一看并没有什么问题,我们控制bool参数的值,每次进行翻转,epoch为奇数,detach有效,参数不更新;反之,偶数则更新,以使得weight1能够隔一次更新一次。然而,实际运行起来,并不能达到我想要的效果:

如果你亲自单步调试的话,会发现第一层的参数再一直变化,并没有被固定住,就像是detach没有起到作用一样。进一步观察第一层变量的梯度的话,你会发现梯度确实也是0,但是参数确实也在更新:

这时候,如果没有意识到真正的问题所在的话,很可能以为是pytorch出了什么bug,或者如果你的模型很复杂,就会很容怀疑是不是模型哪里写错了。然而,问题不在模型,而在于优化器,那么下面我们就来仔细看看。当然, 如果你已经意识到了问题出在哪里,那么我估计也不用花时间往下看了,因为问题完全是因为基础知识不牢固而导致的一个愚蠢的问题,除此之外或许你也可以划到后面简单看看step函数和zero_grad的操作,或许会有一点儿收获。

下面文章分为三个部分:1)detach后参数依旧更新的原因 2)step函数的操作 3)如何达到之前提到的效果

1)detach后参数依旧更新的原因

谜底揭晓,其实就一句话:参数更新并不只依靠当前梯度。如果你知道有一些优化器更新参数的时候会利用历史信息,那就很好理解了。detach确实有用,确实是没有计算当前回传时的梯度,但是优化器可能利用了历史的梯度信息对参数进行了更新。比如我们这里使用的adam:

可以看出(或者没有看出,但是问题不大,直接告诉你):它有累计一阶动量和二阶动量,简单地说也就是,虽然当前的梯度为零,但是因为之前计算过梯度,所以累积量不为零,每次迭代仍然能利用累积量更新参数。

那么更换优化器为最朴素的SGD(默认动量系数为0):

   
   
     
   optimizer = torch.optim.SGD(model.parameters(), lr=1)

我们会发现确实达到了我们想要的效果:第一层的参数在detach的情况下也不会更新了,之后可以将动量的参数"momentum"设置为非零,那么也会观察到我们前面提到的:detach时,也会有参数更新。值得一提的是,pytorch的优化器自带weight_decay参数,来为参数提供正则化的约束(可以简单理解为参数每次更新都有一个递减量),如果这个参数非零,也会出现之前的问题。

那问题又来了,如果第一层参数一直被detach的话,第一层的参数就不会变,也就是weight_decay并不会影响第一层的参数;但是为啥更新一次之后,再detach,weight_decay就会影响到明明已经被detach了的参数呢?这个要从step的操作说起了。

2)step函数的操作

pytorch中参数更新的一个套路性操作:backward,step, zero_grad。backward计算梯度,step根据当前计算的梯度和历史更新记录去更新参数,zero_grad将梯度置零,以防止梯度累计。下图单步展示了三者的功能:

值得一提的是:上面zero_grad将梯度清零,而不是置为之前的None。然后,我们跳转到step的实现部分,注意观察第62行:

只要参数的grad为None,那么优化器就什么都不做,直接跳过这个参数;然而如果不是None,即使是全零,优化器仍然会进行后续的操作,以更新参数。所以一个参数如果一开始没有被detach,使得它的梯度由None转化为tensor,那么之后即使再被detach,参数很可能依旧被更新(被历史梯度的累计或者是weight_decay更新)。

3)如何解决

我最开始的想法是使用两个优化器分别去更新两层的参数。但是如果仔细分析一下就能发现,导致这个问题的原因是,zero_grad将计算得的梯度变为全零张量,而不是变为之前的None。只要将梯度再次转化为None,那么就会避免历史信息和weight_decay对参数的更新了。pytorch已经考虑到了这一点,调用zero_grad时可以指定参数:set_to_none,将其置为True即可。

4)结束

对detach和参数更新的一点认识,欢迎讨论交流和批评指正。

其他文章,欢迎阅读:

惠会绘:关于把图神经网络做到大图(large-scale)上这档事儿

https://zhuanlan.zhihu.com/p/345233657

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取 最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛: 珠港澳人工智能算法大赛 保姆级零基础人工智能教程
算法trick 目标检测比赛中的tricks集锦 从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述: 一文弄懂各种loss function 工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~


觉得有用麻烦给个在看啦~   
登录查看更多
0

相关内容

IM:IFIP/IEEE International Symposium on Integrated Network Management。 Explanation:综合网络管理国际研讨会。 Publisher:IFIP/IEEE SIT: http://dblp.uni-trier.de/db/conf/im/index.html
算法通关手册(LeetCode)
专知会员服务
156+阅读 · 2022年1月13日
专知会员服务
32+阅读 · 2021年10月4日
专知会员服务
90+阅读 · 2020年12月26日
专知会员服务
19+阅读 · 2020年12月9日
专知会员服务
16+阅读 · 2020年10月18日
专知会员服务
15+阅读 · 2020年7月27日
【高能所】如何做好⼀份学术报告& 简单介绍LaTeX 的使用
【ICLR-2020】网络反卷积,NETWORK DECONVOLUTION
专知会员服务
37+阅读 · 2020年2月21日
pytorch提取参数及自定义初始化
极市平台
0+阅读 · 2022年4月13日
pytorch学习 | 提取参数及自定义初始化
极市平台
0+阅读 · 2021年12月21日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
PyTorch 深度剖析:如何保存和加载PyTorch模型?
极市平台
0+阅读 · 2021年11月28日
深度神经网络模型训练中的最新tricks总结【原理与代码汇总】
人工智能前沿讲习班
172+阅读 · 2019年3月6日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
Arxiv
0+阅读 · 2022年4月20日
Arxiv
0+阅读 · 2022年4月18日
Arxiv
0+阅读 · 2022年4月18日
Arxiv
12+阅读 · 2019年4月9日
VIP会员
相关VIP内容
算法通关手册(LeetCode)
专知会员服务
156+阅读 · 2022年1月13日
专知会员服务
32+阅读 · 2021年10月4日
专知会员服务
90+阅读 · 2020年12月26日
专知会员服务
19+阅读 · 2020年12月9日
专知会员服务
16+阅读 · 2020年10月18日
专知会员服务
15+阅读 · 2020年7月27日
【高能所】如何做好⼀份学术报告& 简单介绍LaTeX 的使用
【ICLR-2020】网络反卷积,NETWORK DECONVOLUTION
专知会员服务
37+阅读 · 2020年2月21日
相关资讯
pytorch提取参数及自定义初始化
极市平台
0+阅读 · 2022年4月13日
pytorch学习 | 提取参数及自定义初始化
极市平台
0+阅读 · 2021年12月21日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
PyTorch 深度剖析:如何保存和加载PyTorch模型?
极市平台
0+阅读 · 2021年11月28日
深度神经网络模型训练中的最新tricks总结【原理与代码汇总】
人工智能前沿讲习班
172+阅读 · 2019年3月6日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
Top
微信扫码咨询专知VIP会员