KimbgAI

[ML] Downloading and organizing the STL10 dataset the easy way

KimbgAI 2022. 10. 6. 14:41

While looking for an open dataset for a simple classification task with higher resolution than CIFAR10 (32x32), I came across STL10, a dataset provided by Stanford.

STL10 example images


Data overview

 - train set: 5,000, test set: 8,000, unlabelled: 100,000

 - 10 classes: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck
 - Images are 96x96 pixels, color
 - 500 training images per class (10 pre-defined folds)
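
The script later in this post names each class folder by its numeric label rather than by class name. Below is a minimal sketch for mapping a label back to its name. It assumes the labels run from 1 to 10 in the same order as the class list above; that is my understanding of the binary format, but it is worth confirming against class_names.txt in the extracted archive. `CLASS_NAMES` and `label_to_name` are illustrative names, not part of the original script.

```python
# Class names in the order listed above; ASSUMED to correspond to labels 1..10.
CLASS_NAMES = ["airplane", "bird", "car", "cat", "deer",
               "dog", "horse", "monkey", "ship", "truck"]

def label_to_name(label):
    """Map a 1-based STL10 label to its class name (hypothetical helper)."""
    label = int(label)
    if not 1 <= label <= len(CLASS_NAMES):
        raise ValueError("label out of range: %r" % (label,))
    return CLASS_NAMES[label - 1]

print(label_to_name(1))   # airplane
print(label_to_name(10))  # truck
```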

 

More details about the data are available at the link below.
https://ai.stanford.edu/~acoates/stl10/


Running the code below downloads the dataset into the current folder and organizes it into train, test, and unlabelled folders.

The original provider distributes the data as binary files, so this code converts the binary files into image files and sorts the labelled images into one folder per class, which is convenient.
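
To make the binary-to-image step concrete, here is a small sketch of the decoding on synthetic data. It assumes each image is stored as raw uint8 values with no header, ordered channel first and then column-major within each channel, which is why the reshape is followed by a transpose. `decode_stl10` is a hypothetical helper, and the random buffer only stands in for a real .bin file.

```python
import numpy as np

HEIGHT, WIDTH, DEPTH = 96, 96, 3

def decode_stl10(raw):
    """Turn a raw uint8 buffer into an (N, HEIGHT, WIDTH, DEPTH) image array."""
    flat = np.frombuffer(raw, dtype=np.uint8)
    # Assumed stored order per image: channel, then column, then row
    images = flat.reshape(-1, DEPTH, WIDTH, HEIGHT)
    # Reorder to row, column, channel, which is what image writers expect
    return images.transpose(0, 3, 2, 1)

# Synthetic stand-in for a .bin file: two random images serialised in the assumed layout
rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(2, HEIGHT, WIDTH, DEPTH), dtype=np.uint8)
raw = original.transpose(0, 3, 2, 1).tobytes()

decoded = decode_stl10(raw)
print(decoded.shape)   # (2, 96, 96, 3)
print(len(raw) // 2)   # 27648 bytes per image (96 * 96 * 3)
```

Under the same assumption, train_X.bin with 5,000 images should be 5,000 * 27,648 = 138,240,000 bytes, which gives a quick way to check that a download is complete.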

import os, sys, tarfile, errno
import numpy as np
import matplotlib.pyplot as plt

from imageio import imsave
from tqdm import tqdm
import random
import shutil
    
if sys.version_info >= (3, 0, 0):
    import urllib.request as urllib # ugly but works
else:
    import urllib

HEIGHT = 96
WIDTH = 96
DEPTH = 3

SIZE = HEIGHT * WIDTH * DEPTH

DATA_DIR = './'
DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

TRAIN_DATA_PATH = './stl10_binary/train_X.bin'
TRAIN_LABEL_PATH = './stl10_binary/train_y.bin'

TEST_DATA_PATH = './stl10_binary/test_X.bin'
TEST_LABEL_PATH = './stl10_binary/test_y.bin'

UNLAB_DATA_PATH = './stl10_binary/unlabeled_X.bin'

def download_and_extract():
    """
    Download and extract the STL-10 dataset
    :return: None
    """
    dest_directory = DATA_DIR
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\rDownloading %s %.2f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
        print('Downloaded', filename)
        tarfile.open(filepath, 'r:gz').extractall(dest_directory)

def read_labels(path_to_labels):

    with open(path_to_labels, 'rb') as f:
        labels = np.fromfile(f, dtype=np.uint8)
        return labels

def read_all_images(path_to_data):

    with open(path_to_data, 'rb') as f:
        everything = np.fromfile(f, dtype=np.uint8)

        images = np.reshape(everything, (-1, 3, 96, 96))
        images = np.transpose(images, (0, 3, 2, 1))
        return images

def save_image(image, name):
    imsave("%s.png" % name, image, format="png")

def save_images(images, labels, types):
    # Write each image to <DATA_DIR>/<types>/<label>/<index>.png
    for i, image in enumerate(tqdm(images, position=0)):
        label = labels[i]
        directory = os.path.join(DATA_DIR, types, str(label))
        # exist_ok=True already tolerates an existing folder
        os.makedirs(directory, exist_ok=True)
        save_image(image, os.path.join(directory, str(i)))

def save_unlabelled_images(images):
    # Unlabelled images have no class, so they all go into one folder
    directory = os.path.join(DATA_DIR, 'unlabelled')
    os.makedirs(directory, exist_ok=True)
    for i, image in enumerate(tqdm(images, position=0)):
        save_image(image, os.path.join(directory, str(i)))
        

def create_val_dataset():
    # Optional: move 50 randomly chosen images per class from train/ into val/.
    # Not called by default; run it after the train images have been saved.
    train_image_path = os.path.join(DATA_DIR, "train")
    folders = os.listdir(train_image_path)

    for folder in tqdm(folders, position=0):
        temp_dir = os.path.join(train_image_path, folder)
        temp_image_list = os.listdir(temp_dir)
        val_dir = os.path.join(DATA_DIR, "val", folder)
        os.makedirs(val_dir, exist_ok=True)

        # The move loop must run inside the per-class loop; otherwise only
        # the last class folder gets a validation split.
        # random.sample picks 50 distinct files, so no image is chosen twice.
        for image_name in random.sample(temp_image_list, 50):
            old_name = os.path.join(temp_dir, image_name)
            new_name = os.path.join(val_dir, image_name)
            os.replace(old_name, new_name)


download_and_extract()

train_labels = read_labels(TRAIN_LABEL_PATH)
train_images = read_all_images(TRAIN_DATA_PATH)

test_labels = read_labels(TEST_LABEL_PATH)
test_images = read_all_images(TEST_DATA_PATH)

unlabelled_images = read_all_images(UNLAB_DATA_PATH)

save_images(train_images, train_labels, "train")
save_images(test_images, test_labels, "test")
save_unlabelled_images(unlabelled_images)

This code was adapted from the links below.
https://github.com/mttk/STL10/blob/master/stl10_input.py
https://www.kaggle.com/code/pratt3000/generate-stl10/notebook
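
Once the script has finished, a quick sanity check is to count the PNG files in each class folder. This is a minimal sketch that assumes the ./train layout produced by the script; `count_images` is a hypothetical helper and is not taken from either of the linked sources. Going by the data overview, each class folder under train should hold 500 images, or 450 if create_val_dataset() has been run.

```python
import os

def count_images(split_dir):
    """Return {class_folder: number of .png files} for one split directory."""
    counts = {}
    for name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, name)
        if os.path.isdir(class_dir):
            counts[name] = sum(f.endswith(".png") for f in os.listdir(class_dir))
    return counts

# Only report if the train folder exists in the current directory
if os.path.isdir("./train"):
    print(count_images("./train"))
```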
