一种基于混合数据增强和开集半监督学习的超声心动图角度分类方法及系统

本发明属于医学图像分析,具体涉及一种基于混合数据增强和开集半监督学习的超声心动图角度分类方法及系统。
背景技术:
1、超声心动图被广泛地用于心脏疾病的诊断,在医学领域发挥着重要的作用。在诊断心脏疾病时,医生需要对特定角度的超声心动图进行分析诊断。然而,当前的基于监督学习的超声心动图角度分类方法需要大量的训练成本,超声心动图数据标签的大量标注,需要耗费心脏学专家大量的时间。有些分类方法尝试使用半监督学习方法,但这些方法难以达到训练成本和模型准确性的平衡,导致模型泛化性较差。混合数据增强技术(mixup dataaugmentation)通过线性插值的方法对数据集进行扩增,被广泛地应用于深度学习中,能够提高模型的泛化能力。然而,目前的数据增强方法难以捕获超声心动图中的有效信息。
2、开集半监督学习(open set semi-supervised learning)是一种新的半监督学习范式,用于解决半监督学习受类外分布数据损害的问题。近年来,开集半监督在训练过程中检测并排除类外分布数据的思想,被广泛地用于计算机视觉领域。然而现有的半监督分类方法存在两方面问题:(1)目前的开集半监督分类局限于使用复杂的计算流程去检测类外分布数据,算法流程复杂、效率低下;(2)目前的分类方法尚未关注到如何利用类外分布数据的信息,可能无法充分利用未标记数据集的特征信息,导致分类的性能仍然比较有限。
技术实现思路
1、本发明的目的是为解决现有方法难以捕获超声心动图中的有效信息、对类外分布数据的检测流程复杂且检测效率低、无法充分利用未标记数据集的特征信息的问题,而提出了一种基于混合数据增强和开集半监督学习的超声心动图角度分类方法及系统。
2、本发明为解决上述技术问题所采取的技术方案是:
3、一种基于混合数据增强和开集半监督学习的超声心动图角度分类方法,所述方法具体包括以下步骤:
4、步骤s1、分别获取带类别标签的超声心动图数据集以及不带类别标签的超声心动图数据集,所述类别标签为超声心动图拍摄的角度;
5、步骤s2、对步骤s1中获取到的各个超声心动图分别进行预处理,预处理后图像的尺寸为n×n;
6、步骤s3、分别构建教师模型和学生模型;
7、在教师模型的编码部分中,第三层与第四层之间添加有混合数据增强单元;在学生模型的编码部分中,添加开集半监督单元作为分类器的并行分支;
8、步骤s4、利用步骤s2中预处理后的图像对构建的教师模型和学生模型进行端到端的联合训练;
9、步骤s5、利用训练好的学生模型对待分类超声心动图进行角度分类。
10、进一步地,所述步骤s2的具体过程为:
11、对获取到的超声心动图进行随机水平翻转后,再进行任意大小以及任意方式的像素填充操作,最后通过随机剪裁获得各个尺寸为n×n的图像。
12、进一步地,所述混合数据增强单元包括语义特征图获取模块、特征令牌获取模块、语义特征引导模块、混合掩码学习模块和混合数据生成模块;
13、所述开集半监督单元包括分布内概率计算模块、分布外概率计算模块、超类伪标签生成模块和开集分类器。
14、进一步地,所述语义特征图获取模块采用的是wideresnet、fastvit、repvit或视觉transformer模型;
15、在前向传播的过程中,保存模型的倒数第二层输出的特征作为提取出的语义特征图。
16、进一步地,所述教师模型的训练过程为:
17、对于任一批次数据:
18、步骤s41、对于当前批次中任意一个不带类别标签的超声心动图,超声心动图输入教师模型后,利用语义特征图获取模块输出语义特征图h是语义特征图x的高度,w是语义特征图x的宽度,c是语义特征图x的通道数;
19、特征令牌获取模块将语义特征图x分割成s×s个子特征图,且每个子特征图的大小均为将获得的各个子特征图作为语义特征图x的各个特征令牌;
20、步骤s42、利用语义特征引导模块对步骤s41获得的各个特征令牌分别进行处理,生成加权特征图;
21、步骤s43、采用步骤s41和步骤s42的方法对每个不带类别标签的超声心动图分别进行处理,得到每个不带类别标签的超声心动图对应的加权特征图;
22、步骤s44、利用混合掩码学习模块对任意两个不带类别标签的超声心动图对应的加权特征图进行处理,得到混合掩码:
23、
24、其中,和是两个不带类别标签的超声心动图对应的加权特征图,q表示经过1×1的卷积层后的输出,k表示经过1×1的卷积层后的输出,d是语义特征图的维度,t代表转置,h为常数,sigmod为激活函数,upsample代表上采样操作,mask是混合掩码,代表矩阵乘法;
25、步骤s45、混合数据生成模块利用混合掩码来生成混合数据xm1x:
26、xmix=mask⊙x1+(1-mask)⊙x2
27、其中,x1表示对应的超声心动图,x2表示对应的超声心动图,⊙代表矩阵逐元素相乘;
28、步骤s46、当前批次中带类别标签的超声心动图数据和生成的混合数据经过教师模型后,根据教师模型的分类器输出来计算损失函数;
29、直至损失函数收敛时,停止当前批次的训练。
30、进一步地,所述步骤s42的具体过程为:
31、步骤s421、采用1×1的卷积对各个特征令牌分别进行线性变换后,将第i个特征令牌对应的键、查询和值分别记为和
32、步骤s422、分别对各个特征令牌的查询和键进行平均池化操作,得到平均池化后第i个特征令牌对应的查询和键
33、步骤s423、对于第i个特征令牌,分别计算第i个特征令牌平均池化后的query′i与其它任意一个特征令牌平均池化后的key′k的点积,i≠k,将计算出的各点积结果作为权重矩阵p中的第i行元素,
34、同理,对每个特征令牌分别进行处理后,得到权重矩阵p;
35、步骤s424、从权重矩阵p中检索出前k'个最大的元素值所对应的(key′k,query′i)对,再根据检索出的(key′k,wuery′i)对获得(key′k,value′i)对;
36、步骤s425、对步骤s424中获得的各个(key′k,value′i)对进行收集拼接,得到拼接结果key′和value′;
37、key′=gather(key′k)
38、value′=gather(value′i)
39、其中,key′是对各个(key′k,value′i)对中的key′k进行收集拼接的结果,value′是对各个(key′k,value′i)对中的value′i进行收集拼接的结果,gather(·)表示收集拼接张量的函数,i=1,2,…,k′,k=1,2,…,k′;
40、步骤s426、采用缩放点积注意力,计算出经过注意力得分加权后的特征图xout:
41、
42、其中,采用1×1的卷积对特征图x进行线性变换得到query,是缩放因子。
43、进一步地,所述步骤s46中的损失函数包括带类别标签超声心动图经过分类器的分类损失lce(c,z)和混合数据的生成损失lmix(x1,x2,xmix);
44、
45、其中,z=[p1,p2,p3,p4,...,pc’]为带类别标签超声心动图经过分类器输出的结果,pj表示超声心动图属于第j个类别的概率,c’表示已知类别的个数,c表示超声心动图的类别标签,pc表示分类器预测超声心动图属于类别标签c的概率;
46、
47、其中,cx1表示分类器输出的对x1的分类结果中,最大概率所对应的类别;cx2表示分类器输出的对x2的分类结果中,最大概率所对应的类别;pxmix表示分类器输出的对混合数据xmix的分类结果,pumix=[p′1,p′2,p′3,p′4,...,p′c’],p′1,p′2,p′3,p′4,...,p′c’表示混合数据xmix属于第1个,第2个,第3个,第4个,…,第c’个类别的概率。
48、进一步地,所述学生模型的训练过程为:
49、对于任一批次数据:
50、步骤1、将当前批次数据经过教师模型所获得的混合数据以及当前批次数据输入到学生模型;
51、步骤2、利用分布内概率计算模块输出每个不带标签数据的分布内概率以及带标签数据的分布内概率;
52、步骤3、利用分布外概率计算模块输出每个不带标签数据的分布外概率以及带标签数据的的分布外概率;
53、步骤4、对于任意一个不带标签的数据,利用数据的分布内概率和分布外概率计算开集伪标签;
54、步骤5、根据带标签数据的分布内概率以及带标签数据的的分布外概率计算开集的分类损失,根据步骤4获得的开集伪标签和混合数据经过开集分类器预测的分类结果计算伪标签的生成损失;
55、将开集分类损失和伪标签生成损失的和作为开集总损失,直至开集总损失收敛时,停止当前批次的训练。
56、进一步地,所述分布内概率为:
57、pin=linear(avgpool(relu(bn(x))))
58、其中,x是分布内概率计算模块的输入,pin是x的分布内概率,bn为批归一化层,relu为激活函数,avgpool为平均池化层,linear为全连接层;
59、所述分布外概率计算模块是由c’个二分类器组合成的多头分类器,其中,c’表示已知标签的类别个数;将x经过分布外概率计算模块的输出记为pout,分布外概率
60、利用分布内概率pin和分布外概率pout计算开集伪标签,具体为:
61、步骤(1)、计算分布内概率和分布外概率的点积psp:
62、psp=pinpout
63、步骤(2)、计算类内每个类别所占的权重:
64、
65、其中,zj是分布内概率pin中,x属于第j个类别的概率,αj是第j个类别所占的权重;
66、步骤(3)、计算x对应的开集伪标签:
67、sp=[α1psp[0],α2psp[0],…,αc′psp[0],psp[1]]
68、其中,psp[0]是psp中的第1个元素,psp[1]是psp中的第2个元素。
69、更进一步地,所述开集总损失包括伪标签的生成损失和开集的分类损失,即开集总损失l=lsp+lop;
70、开集的分类损失lsp为:
71、lsp=lin(c,z)+lout(pin,pout)
72、
73、其中,zb[c]表示带标签的数据经过分布内概率计算模块计算出的属于真实类别c的概率,zb[j]表示带标签的数据经过分布内概率计算模块计算出的属于类别j的概率,b表示每个批次中带标签的样本数,表示带标签的数据经过分布外概率计算模块的输出中第c行第1个元素,表示带标签的数据经过分布外概率计算模块的输出中第c′行第2个元素;
74、伪标签的生成损失lop为:
75、
76、vu1,zu1=max(spu1)
77、vu2,zu2=max(spu2)
78、其中,spu1表示无标签数据u1的伪标签,spu2表示无标签数据u2的伪标签,vu1以及zu1分别表示伪标签spu1概率分布中的最大概率以及最大概率所对应的类别索引,vu2以及zu2分别表示伪标签spu2概率分布中的最大概率以及最大概率所对应的类别索引,表示根据u1和u2得到的混合数据经开集分类器的预测结果,μ是混合率。
79、一种基于混合数据增强和开集半监督学习的超声心动图角度分类系统,所述系统包括超声心动图数据集获取模块、超声心动图预处理模块和神经网络模块;其中:
80、超声心动图数据集获取模块,用于获取带类别标签的超声心动图数据集以及不带类别标签的超声心动图数据集,所述类别标签为超声心动图拍摄的角度;
81、超声心动图预处理模块,用于对获取到的各个超声心动图分别进行预处理;
82、预处理的具体过程为:
83、对获取到的超声心动图进行随机水平翻转后,再进行任意大小以及任意方式的像素填充操作,最后通过随机剪裁获得各个尺寸为n×n的图像;
84、神经网络模块包括教师模型和学生模型,在教师模型的编码部分中,第三层与第四层之间添加有混合数据增强单元;在学生模型的编码部分中,添加开集半监督单元作为分类器的并行分支;
85、所述混合数据增强单元包括语义特征图获取模块、特征令牌获取模块、语义特征引导模块、混合掩码学习模块和混合数据生成模块;
86、所述语义特征图获取模块采用的是wideresnet、fastvit、repvit或视觉transformer模型;
87、在前向传播的过程中,保存模型的倒数第二层输出的特征作为提取出的语义特征图;
88、所述开集半监督单元包括分布内概率计算模块、分布外概率计算模块、超类伪标签生成模块和开集分类器;
89、教师模型和学生模型根据预处理后的图像进行端到端的联合训练,即每个批次数据训练完成后,均根据学生模型的参数来更新教师模型的参数,利用最终训练好的学生模型对待分类超声心动图进行角度分类;
90、教师模型的训练过程为:
91、对于任一批次数据:
92、步骤s41、对于当前批次中任意一个不带类别标签的超声心动图,超声心动图输入教师模型后,利用语义特征图获取模块输出语义特征图h是语义特征图x的高度,w是语义特征图x的宽度,c是语义特征图x的通道数;
93、特征令牌获取模块将语义特征图x分割成s×s个子特征图,且每个子特征图的大小均为将获得的各个子特征图作为语义特征图x的各个特征令牌;
94、步骤s42、利用语义特征引导模块对步骤s41获得的各个特征令牌分别进行处理,生成加权特征图;具体为:
95、步骤s421、采用1×1的卷积对各个特征令牌分别进行线性变换后,将第i个特征令牌对应的键、查询和值分别记为和
96、步骤s422、分别对各个特征令牌的查询和键进行平均池化操作,得到平均池化后第i个特征令牌对应的查询和键
97、步骤s423、对于第i个特征令牌,分别计算第i个特征令牌平均池化后的query′i与其它任意一个特征令牌平均池化后的key′k的点积,i≠k,将计算出的各点积结果作为权重矩阵p中的第i行元素,
98、同理,对每个特征令牌分别进行处理后,得到权重矩阵p;
99、步骤s424、从权重矩阵p中检索出前k'个最大的元素值所对应的(key′k,query′i)对,再根据检索出的(key′k,query′i)对获得(key′k,value′i)对;
100、步骤s425、对步骤s424中获得的各个(key′k,value′i)对进行收集拼接,得到拼接结果key′和value′;
101、key′=gather(key′k)
102、value′=gather(value′i)
103、其中,key′是对各个(key′k,value′i)对中的key′k进行收集拼接的结果,value′是对各个(key′k,value′i)对中的value′i进行收集拼接的结果,gather(·)表示收集拼接张量的函数,i=1,2,…,k′,k=1,2,…,k′;
104、步骤s426、采用缩放点积注意力,计算出经过注意力得分加权后的特征图xout:
105、
106、其中,采用1×1的卷积对特征图x进行线性变换得到query,是缩放因子;
107、步骤s43、采用步骤s41和步骤s42的方法对每个不带类别标签的超声心动图分别进行处理,得到每个不带类别标签的超声心动图对应的加权特征图;
108、步骤s44、利用混合掩码学习模块对任意两个不带类别标签的超声心动图对应的加权特征图进行处理,得到混合掩码:
109、
110、其中,和是两个不带类别标签的超声心动图对应的加权特征图,q表示经过1×1的卷积层后的输出,k表示经过1×1的卷积层后的输出,d是语义特征图的维度,t代表转置,h为常数,sigmod为激活函数,upsample代表上采样操作,mask是混合掩码,代表矩阵乘法;
111、步骤s45、混合数据生成模块利用混合掩码来生成混合数据xmix:
112、xmix=mask⊙x1+(1-mask)⊙x2
113、其中,x1表示对应的超声心动图,x2表示对应的超声心动图,⊙代表矩阵逐元素相乘;
114、步骤s46、当前批次中带类别标签的超声心动图数据和生成的混合数据经过教师模型后,根据教师模型的分类器输出来计算损失函数;
115、损失函数包括带类别标签超声心动图经过分类器的分类损失lce(c,z)和混合数据的生成损失lmix(x1,x2,xmix);
116、
117、其中,z=[p1,p2,p3,p4,...,pc’]为带类别标签超声心动图经过分类器输出的结果,pj表示超声心动图属于第j个类别的概率,c’表示已知类别的个数,c表示超声心动图的类别标签,pc表示分类器预测超声心动图属于类别标签c的概率;
118、
119、其中,cx1表示分类器输出的对x1的分类结果中,最大概率所对应的类别;cx2表示分类器输出的对x2的分类结果中,最大概率所对应的类别;表示分类器输出的对混合数据xmix的分类结果,表示混合数据xmix属于第1个,第2个,第3个,第4个,…,第c’个类别的概率;
120、直至损失函数收敛时,停止当前批次的训练;
121、学生模型的训练过程为:
122、对于任一批次数据:
123、步骤1、将当前批次数据经过教师模型所获得的混合数据以及当前批次数据输入到学生模型;
124、步骤2、利用分布内概率计算模块输出每个不带标签数据的分布内概率以及带标签数据的分布内概率;
125、分布内概率为:
126、pin=linear(avgpool(relu(bn(x))))
127、其中,x是分布内概率计算模块的输入,pin是x的分布内概率,bn为批归一化层,relu为激活函数,avgpool为平均池化层,linear为全连接层;
128、步骤3、利用分布外概率计算模块输出每个不带标签数据的分布外概率以及带标签数据的的分布外概率;
129、分布外概率计算模块是由c’个二分类器组合成的多头分类器,其中,c’表示已知标签的类别个数;将x经过分布外概率计算模块的输出记为pout,分布外概率
130、步骤4、对于任意一个不带标签的数据,利用数据的分布内概率和分布外概率计算开集伪标签;具体为:
131、步骤(1)、计算分布内概率和分布外概率的点积psp:
132、psp=pinpout
133、步骤(2)、计算类内每个类别所占的权重:
134、
135、其中,zj是分布内概率pin中,x属于第j个类别的概率,αj是第j个类别所占的权重;
136、步骤(3)、计算x对应的开集伪标签:
137、sp=[α1psp[0],α2psp[0],…,αc′psp[0],psp[1]]
138、其中,psp[0]是psp中的第1个元素,psp[1]是psp中的第2个元素;
139、步骤5、根据带标签数据的分布内概率以及带标签数据的的分布外概率计算开集的分类损失,根据步骤4获得的开集伪标签和混合数据经过开集分类器预测的分类结果计算伪标签的生成损失;
140、开集总损失包括伪标签的生成损失和开集的分类损失,即开集总损失l=esp+eop;
141、开集的分类损失esp为:
142、lsp=lin(c,z)+lout(pin,pout)
143、
144、其中,zb[c]表示带标签的数据经过分布内概率计算模块计算出的属于真实类别c的概率,zb[j]表示带标签的数据经过分布内概率计算模块计算出的属于类别j的概率,b表示每个批次中带标签的样本数,表示带标签的数据经过分布外概率计算模块的输出中第c行第1个元素,表示带标签的数据经过分布外概率计算模块的输出中第c′行第2个元素;
145、伪标签的生成损失lop为:
146、
147、vu1,zu1=max(spu1)
148、vu2,zu2=max(spu2)
149、其中,spu1表示无标签数据u1的伪标签,spu2表示无标签数据u2的伪标签,vu1以及zu1分别表示伪标签spu1概率分布中的最大概率以及最大概率所对应的类别索引,vu2以及zu2分别表示伪标签spu2概率分布中的最大概率以及最大概率所对应的类别索引,表示根据u1和u2得到的混合数据经开集分类器的预测结果,μ是混合率;
150、将开集分类损失和伪标签生成损失的和作为开集总损失,直至开集总损失收敛时,停止当前批次的训练。
151、本发明的有益效果是:
152、本发明基于混合数据增强技术,解决了目前方法难以高效准确地捕获超声心动图有效特征的问题,极大地提高了分类效果的鲁棒性。同时,通过引入开集伪标签的技术,解决了类外分布数据损害半监督算法分类效果的问题。本发明与传统的开集半监督算法检测并过滤分布外数据的思想不同,本发明所提出的方法通过计算不带标签数据属于分布内和分布外概率来构建开集伪标签,并利用开集伪标签参与模型训练,因此充分利用了未标记数据集的特征信息。此外,本发明方法创新性地提出了开集伪标签的混合损失,显著地提升了超声心动图角度分类的精准性。本发明提出的混合数据增强单元和开集分类单元共同训练的策略,实现了数据扩增和模型分类训练的同步化,从而使得训练的分类模型在实际临床应用中更具有可信性。本发明提供的开集半监督和混合数据增强结合的新型超声心动图分类框架,能够提高分类模型的训练效率,同时具有很好的精确性和鲁棒性,具有在临床中被广泛应用的潜力,同时相对于现有方法,本发明具有对类外分布数据检测的流程简单且检测效率高的特点。
技术研发人员:董素宇,马世舟,韩雪,董庆,孙一欣
技术所有人:东北林业大学
备 注:该技术已申请专利,仅供学习研究,如用于商业用途,请联系技术所有人。
声 明 :此信息收集于网络,如果你是此专利的发明人不想本网站收录此信息请联系我们,我们会在第一时间删除
