ArcFace loss: principle and PyTorch implementation
ArcFace pseudocode, step by step:
1. L2-normalize the feature vectors x (row-wise).
2. L2-normalize the weight matrix W (column-wise).
3. Compute the matrix product xW to get the prediction logits y_pred; each entry is cos(theta).
4. Pick out the y_pred entries corresponding to the ground-truth classes.
5. Take the arccosine to recover the angles theta.
6. Add the margin m to each angle.
7. Build a one-hot mask marking the ground-truth positions in y_pred.
8. Use the mask to write cos(theta + m) back into those positions.
9. Multiply all logits by the fixed scale s.
Angle-addition identity: cos(a + b) = cos(a) * cos(b) - sin(a) * sin(b), hence
cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
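This identity is what lets the forward pass compute cos(theta + m) from cos(theta) alone, recovering sin(theta) as sqrt(1 - cos^2(theta)) (valid since theta is in [0, pi], where sin is nonnegative) and never calling arccos. A quick numeric check:

```python
import math

theta, m = 1.1, 0.5  # arbitrary example values in radians
lhs = math.cos(theta + m)
sin_theta = math.sqrt(1 - math.cos(theta) ** 2)  # sin(theta) from cos(theta)
rhs = math.cos(theta) * math.cos(m) - sin_theta * math.sin(m)
assert abs(lhs - rhs) < 1e-12
```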
import math
import torch
from torch.nn import Module, Parameter


def l2_norm(input, axis=1):
    """L2-normalize a tensor along the given axis (helper used by Arcface)."""
    norm = torch.norm(input, 2, axis, True)
    return torch.div(input, norm)


class Arcface(Module):
    # Reference notes (kept from the original post):
    #   https://zhuanlan.zhihu.com/p/76541084
    #   Open question: at inference time, how is the cosine-distance threshold
    #   chosen? Is it just the margin (additive angular margin), default 0.5?
    #   Worth reading carefully:
    #   https://zhuanlan.zhihu.com/p/60747096
    #   https://zhuanlan.zhihu.com/p/62680658
    #   https://www.cnblogs.com/k7k8k91/p/9777148.html
    # On benchmarks ArcFace outperforms the other losses: on the well-known
    # MegaFace challenge, the leaderboard sat around 91% for a long time, and
    # after the ArcFace submission accuracy jumped to about 98%; later entries
    # are much the same, mostly using ArcFace, CosineFace, or similar losses.
    # I have run both open-source implementations; for a rough comparison see
    # https://zhuanlan.zhihu.com/p/52560499
    # Implementation of the additive margin softmax loss in
    # https://arxiv.org/abs/1801.05599 (the ArcFace paper itself is
    # https://arxiv.org/abs/1801.07698).
    def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
        super(Arcface, self).__init__()
        # classnum: number of distinct identities (classes) in the training set
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # Initialize the kernel so each column is already L2-normalized:
        # renorm_(dim=1) clamps every column's norm to 1e-5, then mul_(1e5)
        # scales the columns back up to unit norm.
        self.kernel.data.uniform_(-1, 1).renorm_(p=2, dim=1, maxnorm=1e-5).mul_(1e5)
        self.m = m  # the margin value, default 0.5 (radians)
        self.s = s  # scale, default 64; see NormFace https://arxiv.org/abs/1704.06369
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # = sin(pi - m) * m, used by the fallback below
        self.threshold = math.cos(math.pi - m)

    def forward(self, embeddings, label):
        nB = len(embeddings)  # batch size
        # Normalize the weights column-wise (axis=0) so each class vector has
        # unit norm; the input embeddings were already L2-normalized by the
        # backbone network.
        kernel_norm = l2_norm(self.kernel, axis=0)
        # torch.mm is true matrix multiplication: (nB, 512) x (512, classnum)
        # -> (nB, classnum). Because both factors are unit-normalized, each
        # entry is cos(theta). (torch.mul, by contrast, is element-wise.)
        cos_theta = torch.mm(embeddings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        # Add the margin to the angle via the addition formula:
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # This condition keeps theta + m inside [0, pi]:
        #   0 <= theta + m <= pi  <=>  -m <= theta <= pi - m,
        # i.e. require cos(theta) > cos(pi - m).
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        # When theta + m would fall outside [0, pi], fall back to a
        # CosFace-style additive cosine penalty instead.
        keep_val = (cos_theta - self.mm)
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        # Copy to avoid an in-place modification of cos_theta.
        output = cos_theta * 1.0
        idx_ = torch.arange(0, nB, dtype=torch.long)
        # Replace only the ground-truth logits with cos(theta + m).
        output[idx_, label] = cos_theta_m[idx_, label]
        # Scale up so softmax produces usable gradients (first introduced in
        # NormFace); the result is fed to a standard cross-entropy loss.
        output *= self.s
        return output
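The final multiplication by s matters because after normalization every logit lies in [-1, 1]: softmax over such small logits is nearly uniform, so the cross-entropy gradient is weak. A small self-contained NumPy illustration (example values, not from the original post):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()

logits = np.array([0.8, 0.3, -0.2, -0.5])  # cosine logits, all in [-1, 1]
p_unscaled = softmax(logits)
p_scaled = softmax(64.0 * logits)           # s = 64, as in the class above

print(p_unscaled.max())  # ~0.45: nearly uniform, weak supervision signal
print(p_scaled.max())    # ~1.0: softmax can actually separate the classes
```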