Source code for cleanvision.dataset.torch_dataset
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from PIL import Image
from cleanvision.dataset.base_dataset import Dataset
if TYPE_CHECKING: # pragma: no cover
from torchvision.datasets.vision import VisionDataset
class TorchDataset(Dataset):
    """Wrapper class to handle datasets loaded from torchvision.

    The wrapped dataset is expected to yield, for each index, an iterable of
    objects (for example an ``(image, target)`` tuple). The position of the
    PIL image inside that iterable is detected once, from the first sample,
    and reused for every later lookup.
    """

    def __init__(self, torch_dataset: "VisionDataset") -> None:
        """Wrap ``torch_dataset`` and locate the image within its samples.

        Parameters
        ----------
        torch_dataset:
            Dataset whose samples are iterables containing a ``PIL.Image.Image``.

        Raises
        ------
        ValueError
            If the first sample contains no ``PIL.Image.Image``. Without this
            check the failure would only surface later, as an
            ``AttributeError`` on ``_image_idx`` inside ``__getitem__``.

        Any exception raised by ``torch_dataset[0]`` itself (for example on an
        empty dataset) is propagated unchanged.
        """
        super().__init__()
        self._data = torch_dataset

        # Inspect only the first sample; every other sample is assumed to
        # share its layout.
        image_idx = None
        for i, obj in enumerate(torch_dataset[0]):
            if isinstance(obj, Image.Image):
                # NOTE(review): no early exit, so when a sample holds several
                # PIL images the *last* one is used. Kept for backward
                # compatibility -- confirm this is the intended choice.
                image_idx = i

        if image_idx is None:
            raise ValueError(
                "Could not find a PIL image in the first sample of the dataset; "
                "each sample must contain a PIL.Image.Image object."
            )

        self._image_idx = image_idx
        self._set_index()

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self._data)

    def __getitem__(self, item: Union[int, str]) -> Image.Image:
        """Return the image stored at position ``_image_idx`` of sample ``item``.

        ``item`` is forwarded as-is to the wrapped dataset's indexing.
        """
        return self._data[item][self._image_idx]

    def _set_index(self) -> None:
        """Populate ``self.index`` with the positions ``0 .. len(self) - 1``."""
        self.index = list(range(len(self._data)))