반응형

파이썬(python)에서 사용할 수 있는 파이토치(PyTorch)는 간단하게 딥 러닝을 구현할 수 있는 좋은 라이브러리입니다.

오늘은 파이토치(PyTorch)에서 딥 러닝에서 많이 사용되는 데이터 정보를 관리할 수 있는 데이터 세트(Dataset) 및 데이터 로더(Dataloader)에 대해서 알아보겠습니다.

배열과 같은 데이터 정보는 단순한 형태를 사용하기 편리하지만, 정보가 증가하면 유지 관리하기 어렵습니다.

데이터 세트(Dataset)는 가독성을 높이면서 데이터를 쉽게 액세스 할 수 있도록 도와줍니다.

파이토치(PyTorch) 데이터 세트(Dataset)를 사용하기 위해서는 torchvision 패키지를 인스톨해야 합니다.

torchvision 패키지 인스톨을 진행하면 java SDK를 업데이트합니다.

설치 화면에서 설치를 클릭해주세요.

 

데이터 세트(Dataset)는 기본적으로 사용할 수 있는 데이터 정보를 다운로드하여 저장할 수 있습니다.

training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

FashionMNIST 메서드를 사용해서 이미지 정보를 training_data에 저장합니다.

FashionMNIST 메서드 root는 테스트 데이터 저장 경로입니다.

train은 데이터 세트 지정 정보입니다.

download는 데이터 다운로드 설정 정보입니다.

transform은 레이블 변환 지정입니다.

코드를 실행하면 출력 창에서 이미지 다운로드 화면을 확인할 수 있습니다.

training_data에 저장된 데이터 세트(Dataset) 정보를 출력해보겠습니다.

def showimage():
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()

배열에 저장된 9개의 이미지 정보를 기준으로 데이터 세트(Dataset)에서 이미지를 확인 후 matplotlib 패키지 plt를 사용해서 이미지로 출력합니다.

출력 결과 9개의 이미지가 순차적으로 출력됩니다.

데이터 세트(Dataset)는 예제 정보를 다운로드할 수 있어 간단하게 이미지 관련 딥러닝을 실행할 수 있습니다.

 

데이터 세트(Dataset)는 배열처럼 사용하는 방법보다는 클래스로 구현하면 더욱더 쉽게 사용할 수 있습니다.

기복적으로 __init__, __len__, __getitem__ 세 가지를 함수를 추가 구현할 수 있습니다.

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

클래스에 사용할 수 있는 기본 함수는 자동 생성되기 때문에 필요한 부분을 추가 구현할 수 있습니다.

__init__ 함수는 데이터 세트(Dataset) 객체를 인스턴스화 할 때 한번 실행됩니다.

기본적으로 초기화에 필요한 정보를 입력할 수 있습니다.

__len__ 함수는 데이터 세트(Dataset) 객체 개수를 반환합니다.

__getitem__ 함수는 선택한 인덱스의 데이터 정보를 로드하여 반환합니다.

변환 정보는 처음 배운 텐서로 변환되어 사용할 수 있습니다.

 

데이터 로더(Dataloader)는 데이터 세트 클래스의 확장 형태로 간단하게 데이터 세트를 접근할 수 있습니다.

간단하게 말해서 데이터 로더(Dataloader)는 데이터 세트의 복잡성을 간단한 API로 추상화한 클래스입니다.

데이터 로더(Dataloader)를 사용하기 위해서 torch.utils.data 패키지를 설치합니다.

def DataLoaderEx():
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    img = train_features[1].squeeze()
    label = train_labels[1]
    plt.imshow(img, cmap="gray")
    plt.show()
    print(f"Label: {label}")

데이터 로더(Dataloader)를 로드된 데이터 세트를 생성자에 대입만 하면 바로 사용이 가능합니다.

데이터 로더(Dataloader)는 배열 사용과 동일하게 인덱스를 입력하면 바로 이미지를 확인할 수 있습니다.

데이터 로더(Dataloader) 인덱스 정보를 변경하면 다음 이미지를 확인할 수 있습니다.

파이토치(PyTorch)  데이터 세트(Dataset) 및 데이터 로더(Dataloader)는 이미지 정보와 같은 복잡한 정보를 누구가 쉽게 사용할 수 있도록 구현된 객체입니다.

기본 정보를 다운로드하여 객체에 저장하기 때문에 별도 이미지를 찾을 필요가 없습니다.

딥 러닝은 코드 보다도 알고리즘이 중요하기 때문에 코드를 매우 단순화할 수 있는 좋은 패키지라고 생각됩니다.

감사합니다.

https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

 

Quickstart — PyTorch Tutorials 1.9.0+cu102 documentation

Note Click here to download the full example code Learn the Basics || Quickstart || Tensors || Datasets & DataLoaders || Transforms || Build Model || Autograd || Optimization || Save & Load Model Quickstart This section runs through the API for common task

pytorch.org

 

반응형

+ Recent posts