Skip to content

Commit

Permalink
update voc_annotation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing committed Apr 28, 2022
1 parent dc62f0d commit c7d8e13
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions voc_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import random
import xml.etree.ElementTree as ET

import numpy as np

from utils.utils import get_classes

#--------------------------------------------------------------------------------------------------------------------------------#
Expand Down Expand Up @@ -35,6 +37,10 @@
VOCdevkit_sets = [('2007', 'train'), ('2007', 'val')]
classes, _ = get_classes(classes_path)

#-------------------------------------------------------#
# 统计目标数量
#-------------------------------------------------------#,
nums = np.zeros(len(classes))
def convert_annotation(year, image_id, list_file):
in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
tree=ET.parse(in_file)
Expand All @@ -52,6 +58,8 @@ def convert_annotation(year, image_id, list_file):
b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))

nums[classes.index(cls)] = nums[classes.index(cls)] + 1

if __name__ == "__main__":
random.seed(0)
if annotation_mode == 0 or annotation_mode == 1:
Expand Down Expand Up @@ -107,3 +115,29 @@ def convert_annotation(year, image_id, list_file):
list_file.write('\n')
list_file.close()
print("Generate 2007_train.txt and 2007_val.txt for train done.")

def printTable(List1, List2):
for i in range(len(List1[0])):
print("|", end=' ')
for j in range(len(List1)):
print(List1[j][i].rjust(int(List2[j])), end=' ')
print("|", end=' ')
print()

str_nums = [str(int(x)) for x in nums]
tableData = [
classes, str_nums
]
colWidths = [0]*len(tableData)
len1 = 0
for i in range(len(tableData)):
for j in range(len(tableData[i])):
if len(tableData[i][j]) > colWidths[i]:
colWidths[i] = len(tableData[i][j])
printTable(tableData, colWidths)

if np.sum(nums) == 0:
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
print("(重要的事情说三遍)。")

0 comments on commit c7d8e13

Please sign in to comment.