torch.nn.utils.clip_grad_norm_
函数在 PyTorch 中是一个用于梯度裁剪的工具,其主要目的是控制梯度的范数以防止梯度爆炸问题,这在训练深度神经网络时是一个常见的问题。梯度裁剪通过限制梯度的最大范数来帮助稳定训练过程,尤其是在使用大学习率或深度网络时。
函数签名
torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)
- parameters (Iterable[Tensor] or Tensor): 需要裁剪的参数的迭代器或单个参数。这通常是模型的参数,即
model.parameters()
。 - max_norm (float): 裁剪的范数阈值。如果所有参数的梯度的总范数大于这个值,则梯度会被等比例缩放,使得总范数等于这个值。
- norm_type (float, 可选): 用于计算范数的类型。默认为2,即L2范数(欧几里得范数)。也可以是其他类型,如L1范数(
norm_type=1
)。
工作原理
计算梯度范数:首先,根据指定的
norm_type
(默认为2,即L2范数),计算所有给定参数梯度的总范数。裁剪梯度:如果计算出的总范数大于
max_norm
,则按比例缩放每个参数的梯度,使得缩放后的总范数等于max_norm
。这通常是通过将每个梯度分量除以总范数与max_norm
的比例来实现的。原地修改:注意函数名中的下划线(
_
),这表示该函数会原地(in-place)修改参数的梯度,即直接修改传入参数的.grad
属性,而不是返回新的梯度张量。
示例
import torch
import torch.nn as nn
# 假设我们有一个简单的网络
model = nn.Linear(10, 5)
# 假设我们有一些输入和标签,并进行了前向传播和反向传播
input = torch.randn(3, 10)
target = torch.randn(3, 5)
criterion = nn.MSELoss()
output = model(input)
loss = criterion(output, target)
loss.backward()
# 在梯度反向传播后,我们裁剪梯度
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 现在,model.parameters() 中的每个参数的.grad都已经被裁剪过了
重要性
梯度裁剪是训练稳定神经网络的一种重要技术,特别是在处理梯度爆炸问题时。通过限制梯度的最大范数,可以防止参数更新过大,从而有助于模型的收敛。