首页  专利技术  电子电路装置的制造及其应用技术

双教师类脑蒸馏持续学习图像分类方法、设备及存储介质

2026-02-03 15:40:01 150次浏览
双教师类脑蒸馏持续学习图像分类方法、设备及存储介质

本发明涉及一种图像分类方法,特别涉及一种双教师类脑蒸馏持续学习图像分类方法、设备及存储介质。


背景技术:

1、目前,持续学习指的是一类机器学习任务,其中模型需要连续学习多个任务,模型需要在无旧任务数据(或严格限制其数量)参与联合训练的情况下维持旧任务信息并连续不断地学习新任务。知识蒸馏指的是一类机器学习方法,包含学生模型和教师模型,蒸馏方法通过比较模型对同一输入的响应,使学生模型的响应接近教师模型从而学习其知识。根据教师模型的数量分为单教师蒸馏和多教师蒸馏。

2、近年来,随着transformer架构的提出和预训练模型在视觉领域中的大规模使用,持续学习的图像分类技术取得了巨大的进展。更进一步地,人们在视觉预训练模型的基础上提出了视觉语言预训练模型。然而,预训练视觉语言模型的持续学习在训练过程中除了面临灾难性遗忘外,还面临着零样本迁移能力的遗忘。普通的持续学习方法很难阻止零样本迁移能力的遗忘,而蒸馏方法和基于预训练视觉语言模型的持续学习相结合是一种解决方法。

3、目前已有的方法主要通过利用辅助数据集实现旧模型和当前模型之间的蒸馏,这种方式不仅需要额外的数据集存储空间,这在某些场景下也会受到限制。另外,现有方法采用单教师蒸馏范式,仅使用初始模型或前一任务的模型进行蒸馏,因此在缓解灾难性遗忘和零样本迁移能力遗忘的问题上存在局限。


技术实现思路

1、本发明为解决公知技术中存在的技术问题而提供一种双教师类脑蒸馏持续学习图像分类方法、设备及存储介质。

2、本发明为解决公知技术中存在的技术问题所采取的技术方案是:

3、一种双教师类脑蒸馏持续学习图像分类方法,该方法包括如下步骤:

4、步骤1,编制训练用图像样本集;将图像样本集按照任务类别分割为多个任务图像样本集,每个任务图像样本集分成训练数据集和测试数据集;每个任务的训练数据集包括多个类别图像样本集;每个类别图像样本集包括多张带有标签的图像样本以及对应的标签和类别语义信息;

5、步骤2,构建视觉语言预训练模型,用于从图像数据中得到初始视觉原型;视觉语言预训练模型包括视觉编码器和文本编码器;视觉编码器用于处理图像数据,将图像编码成向量数据;文本编码器用于处理文本数据,将文本编码成向量数据;利用视觉语言预训练模型对各任务中的训练数据集中所有图像样本提取特征;通过计算同一任务中同一类别图像样本的特征平均值,得到每个类别的初始视觉原型;

6、步骤3,将视觉语言预训练模型定义为e0;将e0作为第一教师模型;构建学生模型,使学生模型的结构与视觉语言预训练模型相同,并加载视觉语言预训练模型权重参数;

7、步骤4,对学生模型进行第m个任务的训练,m=1、2、...、t,t为任务个数,训练后的学生模型记为em;对学生模型进行第m+1个任务训练时,em作为第二教师模型,学生模型从em加载权重进行训练;使视觉原型在训练过程中随着批次数据进行滑动更新;

8、步骤5,根据em计算分类损失lce和实例原型对齐损失lcon,根据em-1和e0计算实例原型相似度蒸馏损失,将其平均值并作为最终蒸馏损失ldis;将分类损失lce、实例原型对齐损失lcon、最终蒸馏损失ldis进行加权组合得到优化目标;

9、步骤6,重复步骤4至步骤5,直至所有训练任务的优化目标均达到设定值,训练结束;

10、步骤7,采用完成训练的模型对待分类图像进行分类。

11、进一步地,步骤5中,优化目标的计算公式如下:

12、l=lce+0.2lcon+0.2ldis;

13、式中:

14、l为优化目标;

15、lce为分类损失;

16、lcon为实例原型对齐损失;

17、ldis为最终蒸馏损失。

18、进一步地,步骤2中,每个类别的视觉原型的计算公式如下:

19、

20、式中:

21、c为样本类别序号;

22、为视觉语言预训练模型对第c类样本第j张图像提取的特征;

23、为第c类样本的初始视觉原型;

24、k为第c类样本的样本数量;

25、norm()表示归一化函数。

26、进一步地,步骤4中,使视觉原型在训练过程中随着批次数据进行滑动更新的方法包括如下方法步骤:

27、令滑动更新计算公式如下:

28、

29、式中:

30、为对第c类样本的视觉原型进行第t次滑动更新后的视觉原型;

31、为对第c类样本的视觉原型进行第t+1次滑动更新后的视觉原型;

32、t为滑动次数;

33、c为样本类别序号;

34、为当前模型对该批次第j张图像提取的特征;

35、b为该批次的样本数量;

36、norm()表示归一化函数;

37、γ为滑动系数。

38、进一步地,步骤5中,计算实例原型对齐损失的方法包括如下方法步骤:

39、将第m个任务的第c类样本数据输入学生模型,得到第c类样本数据的视觉特征,利用第c类样本数据的视觉特征、文本特征以及滑动更新后的视觉原型,构建跨模态的实例原型对齐损失;实例原型对齐损失的计算公式如下:

40、

41、式中:

42、lcon为实例原型对齐损失;

43、c为样本类别序号;

44、m为任务序号;

45、为的集合;

46、为对第c类样本的视觉原型进行第t次滑动更新后的视觉原型;

47、hc为类别c的视觉文本特征矩阵;

48、hm为学生模型对应第m个任务输出的视觉文本特征矩阵;

49、pm为第m个任务的原型矩阵;

50、h为中的任一元素;

51、g为中除了h的任一元素;

52、f为hm∪pm中除了h的任一元素;

53、τ为温度系数;

54、sim为余弦相似度函数。

55、进一步地,步骤5中,最终蒸馏损失的计算公式如下:

56、设m为任务序号;当m大于等于2时:

57、

58、

59、

60、当m等于1时:

61、式中:

62、为根据e0计算的实例原型相似度蒸馏损失;

63、为根据ei-1计算的实例原型相似度蒸馏损失;

64、ldis为最终蒸馏损失;

65、()表示弗罗贝尼乌斯范数;

66、sim()表示余弦相似度函数;

67、hm为学生模型对应第m个任务输出的视觉文本特征矩阵;

68、为第一教师模型对应第m个任务输出的视觉文本特征矩阵;

69、为第二教师模型对应第m个任务输出的视觉文本特征矩阵;

70、pm为第m个任务的原型矩阵。

71、进一步地,步骤5中,采用交叉熵损失函数计算分类损失,分类损失的计算公式如下所示:

72、lce=ce(zm,ym);

73、式中:

74、m为任务序号;

75、lce为分类损失;

76、zm为学生模型对应第m个任务输出的logits矩阵;

77、ym为对应第m个任务的标签矩阵;

78、ce()表示交叉熵损失函数。

79、本发明还提供了一种双教师类脑蒸馏持续学习图像分类方法的设备,包括存储器和处理器,所述存储器用于存储计算机程序;所述处理器,用于执行所述计算机程序并在执行所述计算机程序时实现如上述的双教师类脑蒸馏持续学习图像分类方法步骤。

80、本发明还提供了一种存储介质,所述存储介质存储有计算机程序,该计算机程序被处理器执行时,实现如上述的双教师类脑蒸馏持续学习图像分类方法步骤。

81、本发明具有的优点和积极效果是:

82、(1)本发明在基于视觉语言预训练模型和知识蒸馏进行持续学习分类的方法中引入了原型,充分利用多模态信息;(2)本发明仅需要使用训练数据集和之前的模型权重,以一种类脑启发的方式对有限的数据进行蒸馏,不需要使用额外的数据集;(3)本发明采用两个教师模型继续双教师蒸馏,同时缓解灾难性遗忘和零样本迁移能力的遗忘。

文档序号 : 【 40125352 】

技术研发人员:冀中,张鸿盛
技术所有人:天津大学

备 注:该技术已申请专利,仅供学习研究,如用于商业用途,请联系技术所有人。
声 明此信息收集于网络,如果你是此专利的发明人不想本网站收录此信息请联系我们,我们会在第一时间删除
冀中张鸿盛天津大学
一种传感器时序数据异常检测方法及装置 一种保安型机械控流燃气表的制作方法
相关内容