东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2825|回复: 0
打印 上一主题 下一主题

[课堂笔记] ArcFace loss的原理与代码实现_pytorch

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14438
QQ
跳转到指定楼层
楼主
发表于 2020-6-16 10:46:17 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式



ArcFace loss的原理与代码实现_pytorch




ArcFace伪代码实现步骤:
1、对特征向量x按行进行归一化
2、对权重W按列进行归一化
3、计算xW矩阵相乘得到预测向量y_pred
4、从y_pred中挑出与ground truth对应的值
5、计算其反余弦得到角度
6、角度加上m
7、得到从 y_pred 中挑出与ground truth对应的值所在位置的独热码
8、将 cos(theta+m) 通过独热码放回原来的位置
9、对所有值乘上固定值 s



三角函数的公式   cos(a+b) = cos(a)*cos(b) - sin(a) * sin(b)
        # cos(theta + m) = cos(theta) * cos(m)  - sin(theta) * sin(m)


  1. class Arcface(Module):
  2.     # https://zhuanlan.zhihu.com/p/76541084
  3.     # 问题:arcface在判断时如何设置余弦距离的阈值?难道就是那个margin(加性角度间隔)吗?默认=0.5
  4.     # 需要仔细研读:https://zhuanlan.zhihu.com/p/60747096
  5.     # https://zhuanlan.zhihu.com/p/62680658
  6.     # https://www.cnblogs.com/k7k8k91/p/9777148.html
  7.     # 从在数据集表现上来说,ArcFace优于其他几种loss,
  8.     # 著名的megaface赛事,在很长一段时间都停留在91%左右,在使用ArcFace提交后,准确率哗哗就提到了98%,
  9.     # 之后再刷的团队也大同小异,多使用ArcFace, CosineFace等损失函数。
  10.     # 这两个开源算法我也都跑过,粗略地比较了一下,可以见
  11.     # https://zhuanlan.zhihu.com/p/52560499
  12.     # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599   
  13.     def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
  14.         super(Arcface, self).__init__()
  15.         # classnum:训练数据集中共有多少个不同的人
  16.         print("Arcface是损失函数的实现:additive margin softmax loss")
  17.         self.classnum = classnum
  18.         print("Arcface头中:训练数据集共有多少个不同的人(分类):", self.classnum)
  19.         self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
  20.         # initial kernel 初始化时就已经l2归一化了 这里是按列的
  21.         self.kernel.data.uniform_(-1, 1).renorm_(p=2, dim=1, maxnorm=1e-5).mul_(1e5)
  22.         self.m = m  # the margin value, default is 0.5
  23.         self.s = s  # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
  24.         print("Arcface头中:margin={}弧度制,scalar={}".format(m, s))
  25.         self.cos_m = math.cos(m)
  26.         self.sin_m = math.sin(m)
  27.         self.mm = self.sin_m * m  # issue 1
  28.         self.threshold = math.cos(math.pi - m)
  29.         print("Arcface头中:cos_m={},sin_m={},mm={},threshold={}".format(self.cos_m, self.sin_m, self.mm, self.threshold))

  30.     def forward(self, embbedings, label):
  31.         # weights norm
  32.         nB = len(embbedings)
  33.         # print("Arcface:前向计算函数中。。。batch_size=", nB)
  34.         # assert 0 == 1, "停"
  35.         # 对权重归一化后  模==1
  36.         # 输入的特征向量 已经在主干网中进行过l2_norm了
  37.         # 权重的归一化 注意dim=0 按列的
  38.         kernel_norm = l2_norm(self.kernel, axis=0)
  39.         # torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵
  40.         # torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
  41.         # mm  matrix multi 矩阵相乘
  42.         # cos(theta+m)
  43.         # embbedings在主干网络已经L2归一化了  [batch_size, 512]   [512, 分类数目]
  44.         cos_theta = torch.mm(embbedings, kernel_norm)
  45.         # y_pred = torch.mm(embbedings, kernel_norm)
  46.         #         output = torch.mm(embbedings,kernel_norm)
  47.         # 夹紧,夹住 clamp  每个元素的范围限制到区间
  48.         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
  49.         cos_theta_2 = torch.pow(cos_theta, 2)
  50.         sin_theta_2 = 1 - cos_theta_2
  51.         sin_theta = torch.sqrt(sin_theta_2)
  52.         #  反余弦得到角度后 加上 margin  三角函数的公式   cos(a+b) = cos(a)*cos(b) - sin(a) * sin(b)
  53.         # cos(theta + m) = cos(theta) * cos(m)  - sin(theta) * sin(m)
  54.         cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
  55.         # this condition controls the theta+m should in range [0, pi]
  56.         #      0<=theta+m<=pi
  57.         #     -m<=theta<=pi-m
  58.         cond_v = cos_theta - self.threshold
  59.         cond_mask = cond_v <= 0
  60.         keep_val = (cos_theta - self.mm)  # when theta not in [0,pi], use cosface instead
  61.         cos_theta_m[cond_mask] = keep_val[cond_mask]
  62.         output = cos_theta * 1.0  # a little bit hacky way to prevent in_place operation on cos_theta
  63.         idx_ = torch.arange(0, nB, dtype=torch.long)
  64.         output[idx_, label] = cos_theta_m[idx_, label]
  65.         output *= self.s  # scale up in order to make softmax work, first introduced in normface
  66.         return output
复制代码








让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-5-7 05:54 , Processed in 0.167666 second(s), 18 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表