Module earthvision.datasets.vision
Vision Dataset from torchvision/datasets/vision.py
Expand source code
"""Vision Dataset from torchvision/datasets/vision.py"""
import os
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Base class for making datasets which are compatible with torchvision.

    It is necessary to override the ``__getitem__`` and ``__len__`` methods.

    Args:
        root (string): Root directory of dataset. When given as ``str`` or
            ``bytes`` it is passed through ``os.path.expanduser``; any other
            value is stored unchanged.
        transforms (callable, optional): A function/transform that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If ``transforms`` is passed together with ``transform``
            and/or ``target_transform``.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and
        :attr:`target_transform` are mutually exclusive.
    """

    # Number of spaces used to indent the body lines produced by ``__repr__``.
    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}")
        # Use a plain built-in type check instead of ``torch._six.string_classes``:
        # the private ``torch._six`` module is not available in recent PyTorch
        # releases, which made every dataset constructor fail with AttributeError.
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError(
                "Only transforms or transform/target_transform can be passed as argument"
            )

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            # Wrap the two independent callables so that ``self.transforms``
            # can always be called as ``transforms(input, target)``.
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples; must be overridden by subclasses."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a multi-line summary: class name, size, root, extras and transforms."""
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        # ``hasattr`` guards subclasses that never ran this ``__init__``.
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the repr of ``transform`` with ``head``.

        The first line gets ``head``; continuation lines are padded with
        ``len(head)`` spaces so they line up underneath it.
        """
        # ``repr()`` rather than ``transform.__repr__()``: the latter raises
        # TypeError when the callable is itself a class (e.g. ``int``).
        # Fall back to a single empty line so an empty repr cannot raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Return extra text whose lines are inserted into ``__repr__``; empty by default."""
        return ""
class StandardTransform:
    """Pair an input transform with a target transform.

    Calling the instance applies ``transform`` to the input and
    ``target_transform`` to the target, each only if it is not ``None``.

    Args:
        transform (callable, optional): Applied to the input.
        target_transform (callable, optional): Applied to the target.
    """

    def __init__(
        self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
    ) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` after applying the configured transforms."""
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the repr of ``transform`` with ``head``.

        The first line gets ``head``; continuation lines are padded with
        ``len(head)`` spaces so they line up underneath it.
        """
        # ``repr()`` rather than ``transform.__repr__()``: the latter raises
        # TypeError when the callable is itself a class (e.g. ``int``).
        # Fall back to a single empty line so an empty repr cannot raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        """Return the class name followed by one section per configured transform."""
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")
        return "\n".join(body)
Classes
class StandardTransform (transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
-
Expand source code
class StandardTransform:
    """Pair an input transform with a target transform.

    Calling the instance applies ``transform`` to the input and
    ``target_transform`` to the target, each only if it is not ``None``.

    Args:
        transform (callable, optional): Applied to the input.
        target_transform (callable, optional): Applied to the target.
    """

    def __init__(
        self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
    ) -> None:
        self.transform = transform
        self.target_transform = target_transform

    def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]:
        """Return ``(input, target)`` after applying the configured transforms."""
        if self.transform is not None:
            input = self.transform(input)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return input, target

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the repr of ``transform`` with ``head``.

        The first line gets ``head``; continuation lines are padded with
        ``len(head)`` spaces so they line up underneath it.
        """
        # ``repr()`` rather than ``transform.__repr__()``: the latter raises
        # TypeError when the callable is itself a class (e.g. ``int``).
        # Fall back to a single empty line so an empty repr cannot raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def __repr__(self) -> str:
        """Return the class name followed by one section per configured transform."""
        body = [self.__class__.__name__]
        if self.transform is not None:
            body += self._format_transform_repr(self.transform, "Transform: ")
        if self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform, "Target transform: ")
        return "\n".join(body)
class VisionDataset (root: str, transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)
-
Base class for making datasets which are compatible with torchvision. It is necessary to override the
__getitem__
and __len__
methods.
Args
root : string
- Root directory of dataset.
transforms : callable, optional
- A function/transform that takes in an image and a label and returns the transformed versions of both.
transform : callable, optional
- A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
target_transform : callable, optional
- A function/transform that takes in the target and transforms it.
Note
transforms and the combination of transform and target_transform are mutually exclusive.
Expand source code
class VisionDataset(data.Dataset):
    """Base class for making datasets which are compatible with torchvision.

    It is necessary to override the ``__getitem__`` and ``__len__`` methods.

    Args:
        root (string): Root directory of dataset. When given as ``str`` or
            ``bytes`` it is passed through ``os.path.expanduser``; any other
            value is stored unchanged.
        transforms (callable, optional): A function/transform that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If ``transforms`` is passed together with ``transform``
            and/or ``target_transform``.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and
        :attr:`target_transform` are mutually exclusive.
    """

    # Number of spaces used to indent the body lines produced by ``__repr__``.
    _repr_indent = 4

    def __init__(
        self,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}")
        # Use a plain built-in type check instead of ``torch._six.string_classes``:
        # the private ``torch._six`` module is not available in recent PyTorch
        # releases, which made every dataset constructor fail with AttributeError.
        if isinstance(root, (str, bytes)):
            root = os.path.expanduser(root)
        self.root = root

        has_transforms = transforms is not None
        has_separate_transform = transform is not None or target_transform is not None
        if has_transforms and has_separate_transform:
            raise ValueError(
                "Only transforms or transform/target_transform can be passed as argument"
            )

        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        if has_separate_transform:
            # Wrap the two independent callables so that ``self.transforms``
            # can always be called as ``transforms(input, target)``.
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms

    def __getitem__(self, index: int) -> Any:
        """
        Args:
            index (int): Index

        Returns:
            (Any): Sample and meta data, optionally transformed by the respective transforms.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of samples; must be overridden by subclasses."""
        raise NotImplementedError

    def __repr__(self) -> str:
        """Return a multi-line summary: class name, size, root, extras and transforms."""
        head = "Dataset " + self.__class__.__name__
        body = [f"Number of datapoints: {self.__len__()}"]
        if self.root is not None:
            body.append(f"Root location: {self.root}")
        body += self.extra_repr().splitlines()
        # ``hasattr`` guards subclasses that never ran this ``__init__``.
        if hasattr(self, "transforms") and self.transforms is not None:
            body += [repr(self.transforms)]
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)

    def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
        """Prefix the repr of ``transform`` with ``head``.

        The first line gets ``head``; continuation lines are padded with
        ``len(head)`` spaces so they line up underneath it.
        """
        # ``repr()`` rather than ``transform.__repr__()``: the latter raises
        # TypeError when the callable is itself a class (e.g. ``int``).
        # Fall back to a single empty line so an empty repr cannot raise IndexError.
        lines = repr(transform).splitlines() or [""]
        padding = " " * len(head)
        return [f"{head}{lines[0]}"] + [f"{padding}{line}" for line in lines[1:]]

    def extra_repr(self) -> str:
        """Return extra text whose lines are inserted into ``__repr__``; empty by default."""
        return ""
Ancestors
- torch.utils.data.dataset.Dataset
- typing.Generic
Subclasses
- AerialCactus
- Cloud38
- COWC
- DeepSat
- DroneDeploy
- EuroSat
- L7Irish
- L8Biome
- L8SPARCS
- LandCover
- RESISC45
- Sentinel2Cloud
- So2Sat
- SpaceNet7
- UCMercedLand
- XView
Class variables
var functions : Dict[str, Callable]
Methods
def extra_repr(self) ‑> str
-
Expand source code
def extra_repr(self) -> str:
    """Return extra descriptive text for the dataset; the base implementation has none."""
    no_extra_lines = ""
    return no_extra_lines