记录模型训练时loss值的变化情况
作者:暴躁的猴子 发布时间:2022-03-22 23:39:48
记录训练过程中的每一步的loss变化
if verbose and step % verbose == 0:
sys.stdout.write('\r{} / {} : loss = {}'.format(
step, total_steps, np.mean(total_loss)))
sys.stdout.flush()
if verbose:
sys.stdout.write('\r')
sys.stdout.flush()
一般我们在训练神经网络模型的时候,都是每隔多少步,输出打印一下loss或者每一步打印一下loss,今天发现了另一种记录loss变化的方法,就是用
sys.stdout.write('\r{} / {} : loss = {}')
如图上的代码,可以记录每一个在每个epoch中记录用一行输出就可以记录每个step的loss值变化,
\r就是输出不会换行,因此如果你想同一样输出多次,在需要输出的字符串对象里面加上"\r",就可以回到行首了。
sys.stdout.flush() #一秒输出了一个数字
具体的实现就是下面的图:
这样在每个epoch中也可以观察loss变化,但是只需要打印一行,而不是每一行都输出。
补充知识:训练模型中损失(loss)异常分析
前言
训练模型过程中随时都要注意目标函数值(loss)的大小变化。一个正常的模型loss应该随训练轮数(epoch)的增加而缓慢下降,然后趋于稳定。虽然在模型训练的初始阶段,loss有可能会出现大幅度震荡变化,但是只要数据量充分,模型正确,训练的轮数足够长,模型最终会达到收敛状态,接近最优值或者找到了某个局部最优值。在模型实际训练过程中,可能会得到一些异常loss值,如loss等于nan;loss值忽大忽小,不收敛等。
下面根据自己使用Pythorh训练模型的经验,分析出一些具体原因和给出对应的解决办法。
一、输入数据
1. 数据的预处理
输入到模型的数据一般都是经过了预处理的,如用pandas先进行数据处理,尤其要注意空值,缺失值,异常值。
缺失值:数值类型(NaN),对象类型(None, NaN),时间类型(NaT)
空值:""
异常值:不再正常区间范围的值
例如对缺失值可以进行判断df.isnull()或者df.isna();丢弃df.dropna();填充df.fillna()等操作。
输入到模型中的数据一般而言都是数值类型的值,一定要保证不能出现NaN, numpy中的nan是一种特殊的float,该值数值运算的结果是不正常的,所以可能会导致loss值等于nan。可以用numpy.any(numpy.isnan(x))检查一下input和target。
2. 数据的读写
例如使用Pandas读取.csv类型的数据得到的DataFrame会添加默认的index,再写回到磁盘会多一列。如果用其他读取方式再读入,可能会导致数据有问题,读取到NaN。
import pandas as pd
Output = pd.read_csv('./data/diabetes/Output.csv')
trainOutput, testOutput = Output[:6000], Output[6000:]
trainOutput.to_csv('./data/diabetes/trainOutput.csv')
testOutput.to_csv('./data/diabetes/testOutput.csv')
3. 数据的格式
Pythorch中的 torch.utils.data.Dataset 类是一个表示数据集的抽象类。自己数据集的类应该继承自 Dataset 并且重写__len__方法和__getitem__方法:
__len__ : len(dataset) 返回数据集的大小
__getitem__ :用以支持索引操作, dataset[idx]能够返回第idx个样本数据
然后使用torch.utils.data.DataLoader 这个迭代器(iterator)来遍历所有的特征。具体可以参见这里
在构造自己Dataset类时,需要注意返回的数据格式和类型,一般不会出现NaN的情况但是可能会导致数据float, int, long这几种类型的不兼容,注意转换。
二、学习率
基于梯度下降的优化方法,当学习率太高时会导致loss值不收敛,太低则下降缓慢。需要对学习率等超参数进行调参如使用网格搜索,随机搜索等。
三、除零错
对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决。类似于计算概率时进行的平滑修正,下面的代码片段中loss使用交叉混合熵(CossEntropy),计算3分类问题的AUC值,为了避免概率计算出现NaN而采取了相应的平滑处理。
from sklearn.metrics import roc_auc_score
model_ft, y_true, losslists = test_model(model_ft, criterion, optimizer)
n_class = 3
y_one_hot = np.eye(n_class)[y_true.reshape(-1)]
# solve divide zero errot
eps = 0.0000001
y_scores = losslists / (losslists.sum(axis=1, keepdims=True)+eps)
#print(y_scores)
#print(np.isnan(y_scores))
"""
metrics.roc_auc_score(y_one_hot, y_pred)
"""
print("auc: ")
roc_auc_score(y_one_hot, y_scores)
四、loss函数
loss函数代码编写不正确或者已经编写好的loss函数API使用不清楚
五、某些易错代码
Pytorch在进行自动微分的时候,默认梯度是会累加的,所以需要在每个epoch的每个batch中对梯度清零,否则可能会导致loss值不收敛。不要忘记添加如下代码
optimizer.zero_grad()
来源:https://blog.csdn.net/orangefly0214/article/details/83064171
猜你喜欢
- xhEditor简介xhEditor是一个基于jQuery开发的简单迷你并且高效的可视化HTML编辑器,基于网络访问并且兼容IE 6.0+,
- 本文实例讲述了python实现基于两张图片生成圆角图标效果的方法。分享给大家供大家参考。具体分析如下:使用pil的蒙版功能,将原图片和圆角图
- ASP开发中有用的函数(function)集合,挺有用的,请大家保留!'******************************
- __import__() 函数用于动态加载类和函数 。如果一个模块经常变化就可以使用 __import__() 来动态载入。语法__impo
- python用正则表达式提取中文Python re正则匹配中文,其实非常简单,把中文的unicode字符串转换成utf-8格式就可以了,然后
- 目录准备数据集导入所需的软件包将数据从文件加载到Python变量拆分数据进行训练和测试标记化并准备词汇预处理输出标签/类建立Keras模型并
- 背景一直以来,中式占卜都是基于算命先生手工实现,程序繁琐(往往需要沐浴、计算天时、静心等等流程)。准备工作复杂(通常需要铜钱等道具),计算方
- TKinter库,Python 的 GUI 库非常多,之所以选择 Tkinter,一是最为简单,二是自带库,不需下载安装,随时使用,跨平台兼
- 这篇文章主要介绍了Python如何实现强制数据类型转换,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 第一步,下载PHPphp官网地址windows 下载直接解压即可liunx请自行csdn搜索教程第二步,下载code插件1. PHP Deb
- 我们直接先给出输出与预期不同的代码In[28]: a = [1,2,3,4,5,6]In[29]: for i in a: ...: &nb
- 废话不多说,直接上代码吧!import threadingimport osclass Find(threading.Thread): #搜
- 这两副图片哪张更能勾起你买东西的欲望呢?相信大多数买家更喜欢看大图,实物图,产品细节图等.如果我们的卖家更能倾听下我们买家的心声.他们的产品
- 1. 折线图折线图(Line Chart)是一种将数据点按照顺序连接起来的图形,也可以看作是将散点图按照X轴坐标顺序链接起来的图形。折线图的
- 序言这次玩次狠得。除了编译器使用yum安装,其他全部手动编译。哼~看似就Nginx、PHP、MySql三个东东,但是它们太尼玛依赖别人了。没
- PHP get_html_translation_table() 函数实例输出 htmlspecialchars 函数使用的翻译表:<
- 1. 像素基本操作1.1 读取、修改像素可以通过[行,列]坐标来访问像素点数据,对于多通道数据,返回一个数组,包含所有通道的值,对于单通道数
- 换脸!这段时间,deepfakes搞得火热,比方说把《射雕英雄传》里的朱茵换成了杨幂,看下面的图!毫无违和感!其实早在之前,基于AI换脸的技
- 大家好,我是安果!最近在部署前端项目的时候,需要先将前端项目压缩包通过堡垒机上传到应用服务器的 /tmp 目录下,然后进入应用服务器中,使用
- 今天请各位读者朋友欣赏用 Python 实现的鲜花盛宴,你准备好了吗?90 行代码即可实现一棵美丽的鲜花盛开树。小编也是鲜花爱护协会者之一,