东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 7803|回复: 8
打印 上一主题 下一主题

[课堂笔记] 利用IOU值进行K-means聚类,遗传算法得出最佳的anchor

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14429
QQ
跳转到指定楼层
楼主
发表于 2020-11-23 11:17:42 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式


import numpy as np
import torch
from tqdm import tqdm
from scipy.cluster.vq import kmeans


from utils.datasets import LoadImagesAndLabels


# 类似遗传算法的方法,通过一代一代的筛选找到合适的Anchor
# 利用IOU值进行K-means聚类,遗传算法得出最佳的anchor




def wh_iou(wh1, wh2):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter)  # iou = inter / (area1 + area2 - inter)




def kmean_anchors(path='./data/coco64.txt', n=9, img_size=(608, 608), thr=0.20, gen=1000):
    # Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
    # n: number of anchors
    # img_size: (min, max) image size used for multi-scale training (can be same values)
    # thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
    # gen: generations to evolve anchors using genetic algorithm
    # 遗传算法进化几代




    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        iou = wh_iou(wh, torch.Tensor(k))
        max_iou = iou.max(1)[0]
        bpr, aat = (max_iou > thr).float().mean(), (iou > thr).float().mean() * n  # best possible recall, anch > thr
        print('%.2f iou_thr: %.3f best possible recall, %.2f anchors > thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, IoU_all=%.3f/%.3f-mean/best, IoU>thr=%.3f-mean: ' %
              (n, img_size, iou.mean(), max_iou.mean(), iou[iou > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=',  ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k


    def fitness(k):  # mutation fitness
        iou = wh_iou(wh, torch.Tensor(k))  # iou
        max_iou = iou.max(1)[0]
        return (max_iou * (max_iou > thr).float()).mean()  # product


    # Get label wh
    wh = []
    dataset = LoadImagesAndLabels(path, augment=True, rect=True)
    nr = 1 if img_size[0] == img_size[1] else 10  # number augmentation repetitions
    for s, l in zip(dataset.shapes, dataset.labels):
        wh.append(l[:, 3:5] * (s / s.max()))  # image normalized to letterbox normalized wh
    wh = np.concatenate(wh, 0).repeat(nr, axis=0)  # augment 10x
    wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1))  # normalized to pixels (multi-scale)
    wh = wh[(wh > 2.0).all(1)]  # remove below threshold boxes (< 2 pixels wh)


    # Kmeans calculation


    print('开始聚类Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.Tensor(wh)
    k = print_results(k)


    # # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0]<100, 0],400)
    # ax[1].hist(wh[wh[:, 1]<100, 1],400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)


    # Evolve
    npr = np.random
    f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma
    for _ in tqdm(range(gen), desc='Evolving anchors'):
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            print_results(k)


    k = print_results(k)


    return k




result = kmean_anchors(path="/home/jiang/dfy_darknet_works/darknet-master/backup_person_vehicle/train_mini_balance.txt")
print("最终结果:", result)


让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

0

主题

102

帖子

208

积分

中级会员

Rank: 3Rank: 3

积分
208
沙发
发表于 2020-12-6 09:27:32 | 只看该作者

笔记写的挺好的,点个赞
回复

使用道具 举报

0

主题

102

帖子

208

积分

中级会员

Rank: 3Rank: 3

积分
208
板凳
发表于 2020-12-6 09:37:20 | 只看该作者
目标检测之coco与voc格式的数据相互转换,并验证coco格式的脚本目标检测之coco与voc格式的数据相互转换,并验证coco格式的脚本
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
地板
发表于 2021-1-12 16:33:57 | 只看该作者
nameduohaodongx
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
5#
发表于 2021-1-12 16:34:22 | 只看该作者
虽然现在看不懂,但是肯定会的
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
6#
发表于 2021-1-12 16:34:40 | 只看该作者
66666666666666666666666666666666666
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
7#
发表于 2021-1-12 16:35:11 | 只看该作者
每天就只有30次吗
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
8#
发表于 2021-1-12 16:35:27 | 只看该作者
666666666666666666666666666666666666666666666
回复

使用道具 举报

0

主题

98

帖子

206

积分

中级会员

Rank: 3Rank: 3

积分
206
9#
发表于 2021-1-12 16:35:44 | 只看该作者
hao  66666666666666666666666666666666
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-19 13:21 , Processed in 0.202161 second(s), 19 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表