Shortcuts

unike.utils.EarlyStopping 源代码

# coding:utf-8
#
# unike/utils/EarlyStopping.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 5, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
#
# 该脚本定义了 EarlyStopping 类.

"""
EarlyStopping - 使用早停止避免过拟合。
"""

import os
import numpy as np
from ..module.model import Model
from loguru import logger

[文档]class EarlyStopping: """ 如果验证得分(越大越好)在给定的耐心后没有改善,则提前停止训练。 """
[文档] def __init__( self, save_path: str, patience: int = 2, verbose: bool = True, delta: float = 0): """创建 EarlyStopping 对象。 :param save_path: 模型保存目录 :type save_path: str :param patience: 上次验证得分改善后等待多长时间。默认值:2 :type patience: int :param verbose: 如果为 True,则为每个验证得分改进打印一条消息。默认值:True :type verbose: bool :param delta: 监测数量的最小变化才符合改进条件。默认值:0 :type delta: float """ #: 模型保存目录 self.save_path: str = os.path.join(save_path, 'best_network.pth') #: 上次验证得分改善后等待多长时间。默认值:2 self.patience: int = patience #: 如果为 True,则为每个验证得分改进打印一条消息。默认值:True self.verbose: bool = verbose #: 监测数量的最小变化才符合改进条件。默认值:0 self.delta: float = delta #: 计数变量 self.counter: int = 0 #: 保存最好的得分 self.best_score: float = -np.Inf #: 早停开关 self.early_stop: bool = False
[文档] def __call__( self, score: float, model: Model): """ 进行早停记录。 """ if score <= self.best_score + self.delta: self.counter += 1 logger.info(f'EarlyStopping counter: {self.counter} / {self.patience}') if self.counter >= self.patience: self.early_stop = True else: self.save_checkpoint(score, model) self.counter = 0
[文档] def save_checkpoint( self, score: float, model: Model): """ 当验证得分改善时保存模型。 """ if self.verbose: logger.info(f'Validation score improved ({self.best_score:.6f} --> {score:.6f}). Saving model ...') model.save_checkpoint(self.save_path) self.best_score = score

Docs

Access comprehensive developer documentation for UniKE

View Docs