Skip to content

Dataset

介绍

这里用垃圾分类来处理。

Dataset: 提供一种方法把数据加载进来,提供两个方法,分别实现获取数据和label、获取数据总量的功能。

Dataloader: 把数据打包交给模型训练、具体后面再提。

代码实战

这里用了一个二分类数据集:ants和bees

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
  def __init__(self,root_dir,label_dir):
    self.root_dir = root_dir
    self.label_dir = label_dir
    self.path = os.path.join(self.root_dir,self.label_dir)
    self.img_path = os.listdir(self.path)

  def __getitem__(self, idx): # 获取指定索引的样本
    img_name = self.img_path[idx]
    img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
    img = Image.open(img_item_path)
    label = self.label_dir
    return img,label

  def __len__(self):  # 获取样本总数
    return len(self.img_path)


if __name__ == '__main__':
    root_dir = "data\\train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    train_dataset = ants_dataset + bees_dataset
    print(len(train_dataset))
    img,label = train_dataset[200]
    img.show()

​ 包含图片及其label。