200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > MixMatch文章解读+算法流程+核心代码详解

MixMatch文章解读+算法流程+核心代码详解

时间:2021-04-02 04:43:19

相关推荐

MixMatch文章解读+算法流程+核心代码详解

MixMatch

本博客仅做算法流程疏导,具体细节请参见原文

原文

查看原文点这里

Github代码

Github代码点这里

解读

MixMatch抓住了半监督算法的两个重要观点:第一是熵最小化;第二是一致性正则化。结合这两个观点的算法就形成了MixMatch。

熵最小化

半监督算法的一个常见假设就是分类的决策边界不应该通过数据分布的高密度区域。这句话简单的理解可以想象一个聚类模型,其决策边界一定是在簇与簇之间的稀疏边界上,不可能穿过一个簇的中心(高密度区域)。而实现这一点的一种方法就是要求分类器对未标记数据输出低熵预测。MixMatch中使用一个"sharpening"函数来隐式实现熵最小化。所谓熵最小化、低熵预测,都是指使输出概率分布比较有“偏向性”,而不希望输出一个“平均的预测”。熵在信息论中是不确定度的度量,根据离散模型的熵最大定理,可知在均匀分布时熵取得最大值,换句话说,出现一个确定的分布,即某一类的概率是1,其余类的概率是0时,熵为0。也就是说想要得到熵最小,就得使分类器输出后的模型预测概率集中分配给某一类。后面再介绍“sharpening”函数如何实现这一点。

一致性正则化

一致性正则化也是一个常见的半监督假设。VAT、MeanTeacher等其实都或多或少使用了这种假设。其核心在于,我们希望一个样本和其加扰版本(通常图像中称为Augment)通过分类器后,得到相似的输出。其实也就是说分类边界不应该穿过数据分布的高密度区域。如下图,红色点是原始样本,蓝色和绿色为其扰动版本,红色同心圆的虚线圆是我们期望的容差范围,即在这个区间类的都应该认为和其中心数据点为同一类。通过扰动数据点的加入,将决策边界推到合适的位置,使分类器的鲁棒性更强。

一般而言,通过对原始样本和其扰动版本的分类器输出进行衡量,即可实现一致性正则化,常见的衡量方式有MSE、KL散度、JS散度等。在MixMatch中通过对图像的标准数据增强(水平翻转、裁剪)实现扰动(Augment),采用MSE准则方式衡量。

总得来说,算法有以下步骤:

归结而言有五个步骤:

第一步,对数据进行扩增(Augment)。扩增分为对有标记数据集 X X X​的扩增和对无标记数据集 U U U​的扩增,分别记为 X ^ \hat{X} X^​和 U ^ \hat{U} U^​。对 X X X​扩增一次,对 U U U​扩增 K K K​次,文章中取 K = 2 K=2 K=2​。因为在取batch时, B a t c h S i z e U = B a t c h S i z e X Batch Size _U = BatchSize_X BatchSizeU​=BatchSizeX​​,所以扩增后 B a t c h S i z e U ^ = K ⋅ B a t c h S i z e X ^ Batch Size _{\hat{U}} = K\cdot BatchSize_{\hat{X}} BatchSizeU^​=K⋅BatchSizeX^​​​​​。

第二步,计算平均预测分布。此步骤仅对数据集 U ^ \hat{U} U^​​​进行。即通过如下公式计算,其中 ( u b , k ^ , y ) (\hat{u_{b,k}},y) (ub,k​^​,y)​是 U ^ \hat{U} U^​的一个 B a t c h Batch Batch​:

q b ˉ = 1 K ∑ k P m o d e l ( y ∣ u b , k ^ ; θ ) \bar{q_b}=\frac{1}{K}\sum_kP_{model}(y|\hat{u_{b,k}};\theta) qb​ˉ​=K1​k∑​Pmodel​(y∣ub,k​^​;θ)

值得注意的是, P m o d e l ( y ∣ u b , k ^ ; θ ) P_{model}(y|\hat{u_{b,k}};\theta) Pmodel​(y∣ub,k​^​;θ)是 S o f t m a x Softmax Softmax​之后的预测概率分布。

第三步,通过 s h a r p e n i n g sharpening sharpening函数完成分布的锐化,其计算公式如下:

S h a r p e n ( p , T ) i = p i 1 T ∑ j = 1 L p j 1 T Sharpen(p,T)_i=\frac{p_i^{\frac{1}{T}}}{\sum^L_{j=1}p_j^{\frac{1}{T}}} Sharpen(p,T)i​=∑j=1L​pjT1​​piT1​​​​

当超参数 T → 0 T\to 0 T→0​时, S h a r p e n ( p , T ) Sharpen(p,T) Sharpen(p,T)​趋向于 o n e − h o t one-hot one−hot​​分布,即其中一个类别的概率为1,其余概率为0;锐化后的概率分布作为 U ^ \hat{U} U^​的数据标签(pseudo label)。

第四步,通过 M i x U p MixUp MixUp完成新数据集的构建。先将第一步扩增后的 X ^ \hat{X} X^和 U ^ \hat{U} U^进行拼接再打乱顺序,得到 W = S h u f f l e ( C o n c a t ( X ^ , U ^ ) ) W=Shuffle(Concat(\hat{X},\hat{U})) W=Shuffle(Concat(X^,U^)),然后再将 W W W分为两部分,第一部分大小与 X ^ \hat{X} X^相同(也与 X X X相同),记为 W x W_x Wx​;第二部分大小与 U ^ \hat{U} U^相同(也与 U U U相同),记为 W u W_u Wu​。然后将 W x W_x Wx​和 X ^ \hat{X} X^进行 M i x U p MixUp MixUp, W u W_u Wu​和 U ^ \hat{U} U^进行 M i x U p MixUp MixUp,得到 X ′ X' X′和 U ′ U' U′​。 M i x U p MixUp MixUp步骤如下:

λ ∼ B e t a ( α , α ) \lambda\sim Beta(\alpha,\alpha) λ∼Beta(α,α)

λ ′ = m a x ( λ , 1 − λ ) \lambda'=max(\lambda,1-\lambda) λ′=max(λ,1−λ)

x ′ = λ ′ x 1 + ( 1 − λ ′ ) x 2 x'=\lambda'x_1+(1-\lambda')x_2 x′=λ′x1​+(1−λ′)x2​

p ′ = λ ′ p 1 + ( 1 − λ ′ ) p 2 p'=\lambda'p_1+(1-\lambda')p_2 p′=λ′p1​+(1−λ′)p2​

第五步,计算半监督损失函数,分为在标记数据集 X ′ X' X′​上的损失函数 L x L_x Lx​和在无标记数据集 U ′ U' U′上的损失函数 L u L_u Lu​,公式如下:

L x = 1 ∣ X ′ ∣ ∑ x , p ∈ X ′ H ( p , P m o d e l ( y ∣ x ; θ ) ) L_x=\frac{1}{|X'|}\sum_{x,p\in X'}H(p,P_{model}(y|x;\theta)) Lx​=∣X′∣1​x,p∈X′∑​H(p,Pmodel​(y∣x;θ))

L u = 1 L ∣ U ′ ∣ ∑ u , q ∈ U ′ ∣ ∣ q − P m o d e l ( y ∣ u ; θ ) ∣ ∣ 2 2 L_u=\frac{1}{L|U'|}\sum_{u,q\in U'}||q-P_{model}(y|u;\theta)||^2_2 Lu​=L∣U′∣1​u,q∈U′∑​∣∣q−Pmodel​(y∣u;θ)∣∣22​

L = L x + λ U L u L=L_x+\lambda_UL_u L=Lx​+λU​Lu​

其中 H ( ⋅ ) H(\cdot) H(⋅)​是 C o r s s E n t r o p y L o s s CorssEntropyLoss CorssEntropyLoss​; L u L_u Lu​其实就是 M S E MSE MSE准则下的误差项。​

反向梯度传播即可完成整个MixMatch算法

核心代码详解

图像的水平翻转、裁剪实现 A u g m e n t Augment Augment:

transform_train = pose([dataset.RandomPadandCrop(32),dataset.RandomFlip(),dataset.ToTensor(),])transform_val = pose([dataset.ToTensor(),])

这里是在迭代过程中,手动取迭代器中的batch,而不是直接使用Dataloader。这种做法在最近的几篇文章代码复现中都遇见了,其主要目的是为了在一个epoch中可以迭代指定次数,而直接使用Dataloader只能迭代最多 c e i l ( 样 本 总 数 B a t c h S i z e ) ceil(\frac{样本总数}{BatchSize}) ceil(BatchSize样本总数​)次,其中 c e i l ( ⋅ ) ceil(\cdot) ceil(⋅)是上取整函数,如果 d r o p l a s t drop_last dropl​ast,则只能迭代 样 本 总 数 B a t c h S i z e \frac{样本总数}{BatchSize} BatchSize样本总数​次。代码中的两个try except是为了保证迭代器完全迭代一次后,重新加载迭代器,继续迭代,直到达到指定次数才跳转下一个epoch。

for batch_idx in range(args.train_iteration):try:inputs_x, targets_x = labeled_train_iter.next()except:labeled_train_iter = iter(labeled_trainloader)inputs_x, targets_x = labeled_train_iter.next()try:(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()except:unlabeled_train_iter = iter(unlabeled_trainloader)(inputs_u, inputs_u2), _ = unlabeled_train_iter.next()

因为文章中取 K = 2 K=2 K=2,所以进行两次扩增,求输出概率的均值,其中output_uoutput_u2分别为两次扩增后的模型输出结果:

outputs_u = model(inputs_u)outputs_u2 = model(inputs_u2)p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 # 求两次的平均值

求Sharpening结果:

pt = p**(1/args.T)targets_u = pt / pt.sum(dim=1, keepdim=True)targets_u = targets_u.detach()

完成 M i x U p MixUp MixUp:

all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)l = np.random.beta(args.alpha, args.alpha)l = max(l, 1-l)idx = torch.randperm(all_inputs.size(0))input_a, input_b = all_inputs, all_inputs[idx]target_a, target_b = all_targets, all_targets[idx]mixed_input = l * input_a + (1 - l) * input_bmixed_target = l * target_a + (1 - l) * target_b

然后计算损失函数:

logits = [model(mixed_input[0])]for input in mixed_input[1:]:logits.append(model(input))# put interleaved samples backlogits = interleave(logits, batch_size)logits_x = logits[0]logits_u = torch.cat(logits[1:], dim=0)Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u, mixed_target[batch_size:], epoch+batch_idx/args.train_iteration)loss = Lx + w * Lu

反向梯度传播,结束。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。