图像分类实战
1.导入包和下载数据集
import collections
import math
import os
import shutil
import pandas as pd
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
#@save
d2l.DATA_HUB['cifar10_tiny'] = (d2l.DATA_URL + 'kaggle_cifar10_tiny.zip',
'2068874e4b9a9f0fb07ebe0ad2b29754449ccacd')
# 如果使用完整的Kaggle竞赛的数据集,设置demo为False
demo = True
if demo:
data_dir = d2l.download_extract('cifar10_tiny')
else:
data_dir = '../data/cifar-10/'
2.整理数据集
将图片index与其target对应起来,存入字典中
rstrip()
:去除右边换行和空格
f.readlines()[1:]
:读取文件中的每一行,并把第一行去除,第一行一般是行号(没用
#@save
def read_csv_labels(fname):
"""读取fname来给标签字典返回一个文件名"""
with open(fname, 'r') as f:
# 跳过文件头行(列名)
lines = f.readlines()[1:]
tokens = [l.rstrip().split(',') for l in lines]
return dict(((name, label) for name, label in tokens))
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
labels
3.将验证集从原始数据集中拆分
令𝑛
等于样本最少的类别中的图像数量,而𝑟
是比率。
验证集将为每个类别拆分出 max(⌊𝑛𝑟⌋,1
) 张图像。
我们用最简单的方式拆分,创建对应target的文件夹,并把该target的图片存入该文件夹。
这样同类别的图像将被放置在同一文件夹下。
tips:
collections.Counter(labels.values())
:返回一个计数字典,统计了每个类的数量
print出来是这样Counter({'automobile': 112, 'frog': 107, 'truck': 103, ...})
most_common()
方法返回计数字典中出现频率最高的元素,[-1][1]拿到出现频率最低的标签对应的样本数。
然后我们去遍历train
中每一个文件,label = labels[train_file.split('.')[0]]
通过文件序号(也就是文件名)去先前处理好的字典中找到他对应的类。
copyfile(fname, os.path.join(data_dir, 'train_valid_test','train_valid', label))
将其拷贝到对应的类文件下。
if label not in label_count or label_count[label] < n_valid_per_label:
:检查当前样本的标签是否在label_count字典中或者当前标签在验证集中的样本数是否小于n_valid_per_label。如果是,则将当前样本复制到指定路径下的”valid”目录中,同时更新字典。
dict.get(key, value)
:若key不存在则返回value
#@save
def copyfile(filename, target_dir):
"""将文件复制到目标目录"""
os.makedirs(target_dir, exist_ok=True)
shutil.copy(filename, target_dir)
#@save
def reorg_train_valid(data_dir, labels, valid_ratio):
"""将验证集从原始的训练集中拆分出来"""
# 训练数据集中样本最少的类别中的样本数
n = collections.Counter(labels.values()).most_common()[-1][1]
# 验证集中每个类别的样本数
n_valid_per_label = max(1, math.floor(n * valid_ratio))
label_count = {}
for train_file in os.listdir(os.path.join(data_dir, 'train')):
label = labels[train_file.split('.')[0]]
fname = os.path.join(data_dir, 'train', train_file)
copyfile(fname, os.path.join(data_dir, 'train_valid_test',
'train_valid', label))
if label not in label_count or label_count[label] < n_valid_per_label:
copyfile(fname, os.path.join(data_dir, 'train_valid_test',
'valid', label))
label_count[label] = label_count.get(label, 0) + 1
else:
copyfile(fname, os.path.join(data_dir, 'train_valid_test',
'train', label))
return n_valid_per_label
4.调用
#@save
def reorg_test(data_dir):
"""在预测期间整理测试集,以方便读取"""
for test_file in os.listdir(os.path.join(data_dir, 'test')):
copyfile(os.path.join(data_dir, 'test', test_file),
os.path.join(data_dir, 'train_valid_test', 'test',
'unknown'))
def reorg_cifar10_data(data_dir, valid_ratio):
labels = read_csv_labels(os.path.join(data_dir, 'trainLabels.csv'))
reorg_train_valid(data_dir, labels, valid_ratio)
reorg_test(data_dir)
batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_cifar10_data(data_dir, valid_ratio)