基于采样点的图像伪解析类别增量学习方法及装置

本发明属于图像处理及增量学习的,具体涉及一种基于采样点的图像伪解析类别增量学习方法及装置。
背景技术:
1、当前主流的机器学习模型在处理特定任务时已经取得了显著的成效,但其成功的基础往往依赖于人类对特定问题的经验知识。然而,面对现实世界中复杂多变的任务,这些模型往往显得力不从心。传统机器学习模型通常是在固定数据集上进行训练,旨在解决单一任务;但由于现实应用中的数据是连续流动和不断变化的,因此模型需要应对新的数据和任务,然而传统机器学习模型一旦部署,就很难适应新的环境和数据,这限制了其在实际应用中的灵活性和有效性。为了弥补这一不足,增量学习(incremental learning)应运而生,其目标是在不断变化的数据环境中,使得模型能够持续学习和优化,确保在面对新数据时依然表现优异。增量学习的发展历程可以追溯到神经网络和机器学习领域的早期研究。早期的研究主要集中在如何通过增量式的训练方法,提高模型的学习效率和性能;随着深度学习技术的发展,增量学习逐渐成为一个独立且重要的研究方向,其不仅是机器学习领域的重要研究方向,也是实现更高层次人工智能的重要路径,有望在无人驾驶、智能机器人、医疗诊断等领域带来革命性的进展。
2、增量学习(incremental learning)的核心理念是模仿人类大脑的学习过程,通过不断积累和整合新的知识,逐步提高解决问题的能力。增量学习的一个重要特性是能够在处理新任务时保留已学知识,这与传统的机器学习方法形成鲜明对比。传统机器学习方法在引入新任务时往往会出现灾难性遗忘(catastrophic forgetting)现象,即模型在学习新任务时,会丢失之前所学知识,尤其是在使用反向传播学习的网络。而增量学习通过优化算法和模型结构,努力克服这一问题,实现对新旧知识的有效整合。早期缓解灾难性遗忘的方法通常是存储以前的数据,并定期重放从旧样本与新样本中提取的样本交错数据(shinh,lee j k,kim j,et al.continual learning with deep generative replay[j].2017.doi:10.48550/arxiv.1705.08690.)。然而,缺点是需要显式存储旧信息,导致工作内存需求大;此外对于固定数量的神经资源,还需设计专门的机制保护综合知识不被新信息的学习所覆盖。另外灾难性遗忘还可以通过在需要时分配额外的神经资源进行缓解,但这种方法可能会导致可伸缩性问题,从而显著增加神经网络的计算工作量进而变成非常大的架构。现有技术中,研究者们提出了多种增量学习方法,包括基于回放(replay-based)的方法、基于正则化(regularization-based)的方法、以及基于参数隔离(parameterisolation-based)的方法。这些方法在解决灾难性遗忘问题上取得了一定的进展,但仍存在许多挑战和局限。
3、对于基于回放的增量学习方法(如《parisi g i,kemker r,part j l,etal.continual lifelong learning with neural networks:a review[j].neuralnetworks:the official journal of the international neural network society,2019(113-):113.》):该方法在增量学习时回放存储的旧任务数据以克服灾难性遗忘,比如记忆回放,这种方法取得了不错的效果,但是由于其需要存储旧数据,导致内存占用大,而且存储的旧数据存在泄露用户隐私的安全问题。对于基于正则化的增量学习方法(如《french r m.catastrophic forgetting in connectionist networks[j].trends incognitive sciences,1999.doi:10.1016/s1364-6613(99)01294-2.》):在损失函数中添加额外约束的正则化方法,该方法使得对旧知识较重要的权值在学习新数据的过程中不会过度变化,从而一定程度上克服了灾难性遗忘,如lwf、ewc等方法,并且这种方法不存储旧数据,安全性强;但是该方法需要防止旧任务重要参数在新任务训练中的过度偏移现象,因此如何定量判别权值的重要性是一个问题。对于基于参数隔离的增量学习方法(如《ratcliffr.connectionist models of recognition memory:constraints imposed by learningand forgetting functions.[j].psychological review,1990,97(2):285-308.doi:10.1037//0033-295x.97.2.285.》):该方法根据需求在增量学习过程中通过冻结/解冻、添加/删除模型参数等手段,灵活配置模型参数以达到克服灾难性遗忘的效果,但是该方法比较依赖生成器的性能,并且适用性较差。
4、近年来有学者提出了解析类别增量学习(analytical class-incrementallearning,acil)方法,该方法利用多元线性函数最小二乘法的性质,通过矩阵的递推公式计算增量学习的解析解,从而实现类别增量学习,不仅保证了隐私性,而且取得了令人惊奇的良好性能,除此之外,解析类别增量学习还具有良好的理论可解释性。然而该方法仅在模型的最后一层线性层进行增量学习,这显然限制了该类方法的性能;并且随着增量学习的持续学习,复杂度增加,导致误差累积问题变得显著;同时该方法仅适用于线性模型的增量学习,无法推广至非线性模型,进一步限制了其适用性。
技术实现思路
1、本发明的主要目的在于克服现有技术的缺点与不足,提供一种基于采样点的图像伪解析类别增量学习方法及装置,在解析类别增量学习方法的基础上,使用采样点法进行改进,不仅可以处理线性模型,还可以在训练过程中有效地处理非线性模型的复杂性,并在不同的训练阶段保持模型的稳定性和精度。
2、为了达到上述目的,本发明采用以下技术方案:
3、本发明第一目的在于,提供一种基于采样点的图像伪解析类别增量学习方法,包括下述步骤:
4、s1、构建图像分类网络,包含背景网络及解析分类器;所述背景网络通过卷积神经网络提取图像特征,所述解析分类器使用全连接网络对图像特征进行分类,以独热编码形式输出分类结果;所述解析分类器基于线性模型或非线性模型进行构建;
5、s2、使用反向梯度下降算法在基础数据集上对后接有softmax分类器的背景网络进行预训练,结束后冻结背景网络的参数并得到基础数据集的一维化特征;
6、s3、去除softmax分类器,在冻结参数的背景网络后接入解析分类器,并在解析分类器中增加一个虚拟节点参与后续操作;对应的分类结果增加一个维度,该维度为对应虚拟节点的始终为0的虚拟标签;
7、s4、对基础数据集的一维化特征进行随机维度拓展,输入解析分类器使用反向梯度下降算法进行重对齐,结束后计算缩放系数并得到初始采样点集;所述初始采样点集中采样点的数量由解析分类器确定并保持恒定;
8、s5、获取当前阶段的增量数据集,输入冻结参数的背景网络中获取新增特征集,与初始采样点集进行混合得到混合数据集;
9、s6、采用mse损失函数通过反向梯度下降算法在混合数据集上对上一阶段的解析分类器进行增量学习,并更新采样点作为下一阶段的初始采样点集;
10、s7、获取下一阶段的增量数据集,重复步骤s5-s6,直至达到设定阶段或无下一阶段的增量数据集,得到学习好的图像分类器。
11、作为优选的技术方案,所述步骤s2中,所述对后接有softmax分类器的背景网络进行预训练过程为:
12、将基础数据集输入背景网络中进行特征提取,得到基础数据特征集;
13、对基础数据特征集进行特征一维扁平化,得到基础数据集的一维化特征;
14、后接的softmax分类器对基础数据集的一维化特征进行分类,以独热编码形式输出分类结果,表示为:z=fsoftmax(fflat(fcnn(x0,wcnn),wfcn));其中,fsoftmax为softmax分类器,fflat为一维压平操作,fcnn为背景网络,x0为基础数据集,wcnn为背景网络参数,wfcn为一维压平操作中线性输出层的参数,将一维化特征的维度调整为与基础数据集真实标签相同的维度;
15、基于分类结果与基础数据集的真实标签计算损失函数,通过反向梯度下降算法更新网络参数,直至损失函数收敛或达到最大迭代次数,冻结背景网络的参数。
16、作为优选的技术方案,步骤s4中,所述输入解析分类器使用反向梯度下降算法进行重对齐,步骤为:
17、使用激活函数对基础数据集的一维化特征进行随机维度拓展,得到拓展后的特征集
18、
19、其中,fact为激活函数,fflat为一维压平操作,fcnn为背景网络,x0为基础数据集,wcnn为背景网络参数,wfcn为一维压平操作中线性输出层的参数,将一维化特征的维度调整为与基础数据集真实标签相同的维度,wfe为随机生成的参数拓展矩阵;
20、将拓展后的特征集输入解析分类器使用mse损失函数通过反向梯度下降算法进行重对齐;重对齐过程中,所述解析分类器参数的优化函数为:
21、
22、其中,wil为解析分类器的参数,y0为基础数据集的真实标签集,fil为解析分类器,表示基础数据集的预测标签集,||·||f为frobenius范数;
23、完成后,随机生成一组设定数量的采样点及其对应的预测标签,基于拓展后的特征集及预测标签集对采样点及其对应的预测标签进行优化,计算缩放系数并得到初始采样点集。
24、作为优选的技术方案,所述初始采样点集获取步骤为:
25、随机生成一组设定数量的采样点及其对应的预测标签;所述采样点的坐标及其对应的预测标签随机生成;
26、基于采样点的设定数量和基础数据集的数量,计算缩放系数λ0,公式为:
27、
28、其中,nsam为采样点的设定数量,n0为基础数据集的数量;
29、计算mse损失函数在基础数据集上对训练后解析分类器参数的一阶梯度向量g0和二阶梯度矩阵g0,以及mse损失函数在采样点上对训练后解析分类器参数的一阶梯度向量和二阶梯度矩阵
30、使用梯度下降算法分别优化采样点坐标及其对应的预测标签,使得二阶梯度矩阵近似相等于λ0g0且一阶梯度向量近似相等于λ0g0,得到初始采样点集。
31、作为优选的技术方案,步骤s6中,所述在混合数据集上对上一阶段的解析分类器进行增量学习,具体为:
32、基于当前阶段增量数据集中的新增标签数量,将构建的虚拟节点及其在当前阶段对应的解析分类器参数进行复制,将最后一个复制的虚拟节点作为新的虚拟节点,其余复制的虚拟节点作为真实节点与当前阶段解析分类器的真实节点参与后续增量学习;所述复制次数等于新增标签数量;所述新的虚拟节点对应的虚拟标签始终为0;
33、在混合数据集上对上一阶段的解析分类器进行训练,得到当前阶段训练好的解析分类器;
34、重新生成一组设定数量的采样点及其对应的预测标签;基于采样点的设定数量和解析分类器历史累计的数据量,计算当前阶段的缩放系数;
35、分别计算mse损失函数在当前阶段的混合数据集以及重新生成的采样点上对当前阶段训练好的解析分类器参数的一阶梯度向量和二阶梯度矩阵,对重新生成的采样点进行梯度对齐和误差矫正作为下一阶段的初始采样点集。
36、作为优选的技术方案,训练过程中,所述解析分类器参数的优化函数为:
37、
38、其中,wil为解析分类器的参数,yk为当前第k阶段增量数据集的真实标签集,为当前第k阶段增量数据集的新增特征集,fil为解析分类器,表示当前第k阶段增量数据集的预测标签集,||·||f为frobenius范数,λk-1为初始采样点集的缩放系数,为初始采样点集损失函数的权值,为初始采样点集,为初始采样点集的真实标签集,为初始采样点集的预测标签集;
39、所述缩放系数的计算公式为:
40、
41、其中,λk为当前第k阶段的缩放系数,nsam为采样点的设定数量,nk为解析分类器历史累计的数据量。
42、作为优选的技术方案,所述对重新生成的采样点进行梯度对齐和误差矫正作为下一阶段的初始采样点集,具体为:
43、计算mse损失函数在当前阶段的混合数据集上对当前阶段训练好的解析分类器参数的一阶梯度向量gk和二阶梯度矩阵gk;
44、计算mse损失函数在当前阶段重新生成的采样点上对当前阶段训练好的解析分类器参数的一阶梯度向量和二阶梯度矩阵
45、使用梯度下降算法优化采样点坐标,使得二阶梯度矩阵近似相等于λkgk完成梯度对齐;
46、将优化后的采样点输入当前阶段训练好的解析分类器得到重新生成采样点的预测标签,并使用梯度下降算法优化采样点预测标签坐标,使得一阶梯度向量近似相等于λkgk完成误差矫正;
47、重复上述步骤直至所有重新生成采样点完成优化,优化后的采样点集作为下一阶段的初始采样点集。
48、本发明第二目的在于提供一种基于采样点的图像伪解析类别增量学习系统,包括分类网络构建模块、背景网络预训练模块、虚拟节点增加模块、解析分类器预训练模块、数据混合模块、解析分类器训练模块以及增量学习终止模块;
49、所述分类网络构建模块用于构建图像分类网络,包含背景网络及解析分类器;所述背景网络通过卷积神经网络提取图像特征,所述解析分类器使用全连接网络对图像特征进行分类,以独热编码形式输出分类结果;所述解析分类器基于线性模型或非线性模型进行构建;
50、所述背景网络预训练模块用于使用反向梯度下降算法在基础数据集上对后接有softmax分类器的背景网络进行预训练,结束后冻结背景网络的参数并得到基础数据集的一维化特征;
51、所述虚拟节点增加模块用于去除softmax分类器,在冻结参数的背景网络后接入解析分类器,并在解析分类器中增加一个虚拟节点参与后续操作;对应的分类结果增加一个维度,该维度为对应虚拟节点的始终为0的虚拟标签;
52、所述解析分类器预训练模块用于对基础数据集的一维化特征进行随机维度拓展,输入解析分类器使用反向梯度下降算法进行重对齐,结束后计算缩放系数并得到初始采样点集;所述初始采样点集中采样点的数量由解析分类器确定并保持恒定;
53、所述数据混合模块用于获取当前阶段的增量数据集,输入冻结参数的背景网络中获取新增特征集,与初始采样点集进行混合得到混合特征集;
54、所述解析分类器训练模块用于采用mse损失函数通过反向梯度下降算法在混合特征集上对上一阶段的解析分类器进行增量学习,并更新采样点作为下一阶段的初始采样点集;
55、所述增量学习终止模块用于获取下一阶段的增量数据集,重复执行数据混合模块和解析分类器训练模块,直至达到设定阶段或无下一阶段的增量数据集,得到学习好的图像分类器。
56、本发明第三目的在于提供一种电子设备,包括:
57、至少一个处理器;以及,与所述至少一个处理器通信连接的存储器;其中,
58、所述存储器存储有可被所述至少一个处理器执行的计算机程序指令,所述计算机程序指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行上述的基于采样点的图像伪解析类别增量学习方法。
59、本发明第四目的在于提供一种计算机可读存储介质,存储有程序,当程序被处理器执行时,实现上述的基于采样点的图像伪解析类别增量学习方法。
60、本发明与现有技术相比,具有如下优点和有益效果:
61、本技术方法使用采样点法对解析类别增量学习方法进行改进,通过对背景网络预训练冻结其参数,减少了背景网络参与后续训练的计算量;接着进行梯度重对齐构建初始采样点集,保留解析分类器在基础数据集上学习到的知识;然后开始增量学习,更新采样点保留每一阶段增量学习后的知识;本技术与其它增量学习方法相比,初始采样点集中存储了旧训练数据的知识,因此无需访问旧训练数据即可完成后续增量学习,具有隐私良好的安全性;同时存储的采样点也无法通过任何逆向工程还原出任何训练数据,仅是满足增量学习的数学性质,能够很大程度上减轻灾难性遗忘;并且通过采样点法保存学习到的知识,具有接近解析级的精度且可通用于非线性模型。
技术研发人员:程宇轩,庄辉平
技术所有人:华南理工大学
备 注:该技术已申请专利,仅供学习研究,如用于商业用途,请联系技术所有人。
声 明 :此信息收集于网络,如果你是此专利的发明人不想本网站收录此信息请联系我们,我们会在第一时间删除
