|
# Split the dataset into four subsets in one step: trainval, train, val and test.
import os
import random
xmlfilepath = r'./VOCdevkit/VOC2007/Annotations'
saveBasePath = r"./VOCdevkit/VOC2007/ImageSets/Main/"
# Split ratios for the four subsets (trainval, train, val, test):
# trainval_percent is the fraction of all samples placed in trainval (the
# remainder becomes test); train_percent is the fraction of trainval placed
# in train (the remainder becomes val).
trainval_percent = 0.9
train_percent = 0.8


def split_dataset(xml_dir=xmlfilepath, save_dir=saveBasePath,
                  trainval_ratio=trainval_percent,
                  train_ratio=train_percent):
    """Randomly split the ``.xml`` files in *xml_dir* into four subsets.

    Writes ``trainval.txt``, ``train.txt``, ``val.txt`` and ``test.txt`` into
    *save_dir*, one file name (without the ``.xml`` extension) per line, in
    directory-listing order. *save_dir* is created if it does not exist.

    Args:
        xml_dir: Directory that is scanned for ``.xml`` files.
        save_dir: Directory that receives the four ``.txt`` files.
        trainval_ratio: Fraction of all files assigned to trainval; the
            remaining files are assigned to test.
        train_ratio: Fraction of trainval assigned to train; the remaining
            trainval files are assigned to val.

    Returns:
        A dict mapping ``"trainval"``, ``"train"``, ``"val"`` and ``"test"``
        to the list of names written to the corresponding file.

    Raises:
        FileNotFoundError: If *xml_dir* does not exist.
        ValueError: If a ratio asks ``random.sample`` for more items than
            are available (or for a negative number of items).
    """
    total_xml = [f for f in os.listdir(xml_dir) if f.endswith(".xml")]
    num = len(total_xml)
    print("一共多少个标注文件:", num)

    tv = int(num * trainval_ratio)
    tr = int(tv * train_ratio)
    # Sample from sequences (random.sample does not accept sets on newer
    # Pythons), then convert to sets for O(1) membership tests below.
    trainval = random.sample(range(num), tv)
    train = random.sample(trainval, tr)
    print("train and val size", tv)
    print("train size", tr)
    trainval_ids = set(trainval)
    train_ids = set(train)

    subsets = {"trainval": [], "train": [], "val": [], "test": []}
    for i, xml in enumerate(total_xml):
        # Strip the extension from THIS file's name; slicing the whole list
        # and adding a string to it would raise TypeError.
        name = os.path.splitext(xml)[0]
        if i not in trainval_ids:
            subsets["test"].append(name)
            continue
        subsets["trainval"].append(name)
        if i in train_ids:
            subsets["train"].append(name)
        else:
            subsets["val"].append(name)

    os.makedirs(save_dir, exist_ok=True)
    for subset, names in subsets.items():
        # The context manager closes each file even if a write fails.
        with open(os.path.join(save_dir, subset + ".txt"), "w") as out:
            out.writelines(name + "\n" for name in names)
    return subsets


if __name__ == "__main__":
    split_dataset()
|
|