当前位置: 技术文章>> 100道python面试题之-在使用PyTorch进行深度学习训练时,如何实施早停(Early Stopping)策略?

文章标题:100道python面试题之-在使用PyTorch进行深度学习训练时,如何实施早停(Early Stopping)策略?
  • 文章分类: 后端
  • 7646 阅读

在PyTorch中实施早停(Early Stopping)策略是一种常用的技巧,用于防止模型在训练集上过拟合,同时在验证集上性能不再提升时提前停止训练。早停通常通过监控模型在验证集上的性能(如准确率、损失等)来实现,一旦验证集上的性能在一定轮次内没有改进,则停止训练。

以下是一个简单的早停策略实现步骤,使用PyTorch的torch.utils.data.DataLoader来加载数据,并使用自定义的早停类来管理训练过程:

  1. 定义早停类:首先,定义一个早停类,该类中包括一个计数器来记录连续多少次验证集上的性能没有改善,以及性能改善的阈值、最大训练轮次等参数。

  2. 训练循环:在训练循环中,每次迭代后都计算验证集上的性能指标,并与之前的最佳性能进行比较。

  3. 性能判断:如果当前性能比之前的最佳性能有所提升,则更新最佳性能,并重置计数器。如果当前性能没有提升,则增加计数器。

  4. 停止条件:如果计数器达到设定的阈值,则提前停止训练。

下面是一个简化的代码示例:

import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pth.tar'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pth.tar'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss  # 假设我们监控的是损失,我们希望它尽可能小

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# 假设你已经定义了模型、优化器、损失函数、训练集和验证集
# ...

# 实例化早停类
early_stopping = EarlyStopping(patience=10, verbose=True)

# 训练循环
for epoch in range(num_epochs):
    # 训练模型
    # ...

    # 验证模型
    val_loss = validate(model, val_loader)  # 假设validate函数计算并返回验证集上的损失

    # 检查是否需要早停
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

注意,这个示例假设你正在最小化验证集上的损失。如果你的目标是最大化验证集上的某个性能指标(如准确率),则需要对代码进行适当调整。

推荐文章