Module earthvision.models.resisc45.utils
Expand source code
import os
import warnings
import torch
import gdown
ENV_TORCH_HOME = "TORCH_HOME"
ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
DEFAULT_CACHE_DIR = "~/.cache"
_hub_dir = None
def _get_torch_home():
torch_home = os.path.expanduser(
os.getenv(
ENV_TORCH_HOME, os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch")
)
)
return torch_home
def get_dir():
r"""
Get the Torch Hub cache directory used for storing downloaded models & weights.
If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where
environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
filesystem layout, with a default value ``~/.cache`` if the environment
variable is not set.
"""
# Issue warning to move data if old env is set
if os.getenv("TORCH_HUB"):
warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
if _hub_dir is not None:
return _hub_dir
return os.path.join(_get_torch_home(), "hub")
def set_dir(d):
r"""
Optionally set the Torch Hub directory used to save downloaded models & weights.
Args:
d (string): path to a local folder to save downloaded models & weights.
"""
global _hub_dir
_hub_dir = d
def load_state_dict_from_url(url, model_dir=None, map_location=None):
r"""Loads the Torch serialized object at the given URL.
If downloaded file is a zip file, it will be automatically
decompressed.
If the object is already present in `model_dir`, it's deserialized and
returned.
The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where
``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`.
Args:
url (string): URL of the object to download
model_dir (string, optional): directory in which to save the object
map_location (optional): a function or a dict specifying how to remap storage locations
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
cached_file = os.path.join(model_dir, url[1])
if not os.path.exists(cached_file):
gdown.download(url[0], cached_file, quiet=False)
return torch.load(cached_file, map_location=map_location)
Functions
def get_dir()
-
Get the Torch Hub cache directory used for storing downloaded models & weights. If :func:
~torch.hub.set_dir
is not called, default path is$TORCH_HOME/hub
where environment variable$TORCH_HOME
defaults to$XDG_CACHE_HOME/torch
.$XDG_CACHE_HOME
follows the X Design Group specification of the Linux filesystem layout, with a default value~/.cache
if the environment variable is not set.Expand source code
def get_dir(): r""" Get the Torch Hub cache directory used for storing downloaded models & weights. If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux filesystem layout, with a default value ``~/.cache`` if the environment variable is not set. """ # Issue warning to move data if old env is set if os.getenv("TORCH_HUB"): warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") if _hub_dir is not None: return _hub_dir return os.path.join(_get_torch_home(), "hub")
def load_state_dict_from_url(url, model_dir=None, map_location=None)
-
Loads the Torch serialized object at the given URL. If downloaded file is a zip file, it will be automatically decompressed. If the object is already present in
model_dir
, it's deserialized and returned. The default value ofmodel_dir
is<hub_dir>/checkpoints
wherehub_dir
is the directory returned by :func:~torch.hub.get_dir
.Args
url
:string
- URL of the object to download
model_dir
:string
, optional- directory in which to save the object
map_location
:optional
- a function or a dict specifying how to remap storage locations
Expand source code
def load_state_dict_from_url(url, model_dir=None, map_location=None): r"""Loads the Torch serialized object at the given URL. If downloaded file is a zip file, it will be automatically decompressed. If the object is already present in `model_dir`, it's deserialized and returned. The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. Args: url (string): URL of the object to download model_dir (string, optional): directory in which to save the object map_location (optional): a function or a dict specifying how to remap storage locations """ if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, "checkpoints") os.makedirs(model_dir, exist_ok=True) cached_file = os.path.join(model_dir, url[1]) if not os.path.exists(cached_file): gdown.download(url[0], cached_file, quiet=False) return torch.load(cached_file, map_location=map_location)
def set_dir(d)
-
Optionally set the Torch Hub directory used to save downloaded models & weights.
Args
d
:string
- path to a local folder to save downloaded models & weights.
Expand source code
def set_dir(d): r""" Optionally set the Torch Hub directory used to save downloaded models & weights. Args: d (string): path to a local folder to save downloaded models & weights. """ global _hub_dir _hub_dir = d