Module earthvision.datasets.utils

Utility functions.

Expand source code
"""Utility functions."""
import collections
import os
import ssl
import sys
import urllib
import urllib.request  # explicit: `import urllib` alone does not load the submodule

import boto3
import numpy as np
from botocore import UNSIGNED
from botocore.client import Config
from PIL import Image
from spectral import open_image
from tqdm import tqdm

# define MAX_IMAGE_PIXELS
Image.MAX_IMAGE_PIXELS = 1000000000


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    """Stream *url* to *filename* with a tqdm progress bar.

    Args:
        url: Source URL.
        filename: Destination filepath.
        chunk_size: Bytes requested per read.

    """
    # NOTE: certificate verification is disabled on purpose so dataset
    # mirrors behind misconfigured TLS still download; do not reuse this
    # context for sensitive traffic.
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    with open(filename, "wb") as fh:
        with urllib.request.urlopen(urllib.request.Request(url), context=ctx) as response:
            with tqdm(total=response.length) as pbar:
                # b"" is the correct sentinel: response.read returns bytes,
                # so the original "" sentinel could never match and only
                # the fallback falsy check ended the loop.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    # len(chunk) is exact for the final partial chunk;
                    # updating by chunk_size overstated progress.
                    pbar.update(len(chunk))
                    fh.write(chunk)


def s3_downloader(s3_client, local_file_name: str, s3_bucket: str, s3_object_key: str):
    """Download dataset from Amazon S3, printing a 50-char progress bar.

    Args:
        s3_client: Object boto3.client.
        local_file_name: Destination filepath.
        s3_bucket: S3 bucket.
        s3_object_key: S3 object key.

    """
    meta_data = s3_client.head_object(Bucket=s3_bucket, Key=s3_object_key)
    total_length = int(meta_data.get("ContentLength", 0))
    downloaded = 0

    def progress(chunk):
        # Called by download_fileobj with the number of bytes transferred.
        nonlocal downloaded
        downloaded += chunk
        # Guard against a missing/zero ContentLength, which previously
        # raised ZeroDivisionError on the first callback.
        done = int(50 * downloaded / total_length) if total_length else 0
        sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done)))
        sys.stdout.flush()

    print(f"Downloading {s3_object_key}")
    with open(local_file_name, "wb") as f:
        s3_client.download_fileobj(s3_bucket, s3_object_key, f, Callback=progress)


def downloader(resource: str, root: str):
    """Downloader function that handles a general download link or S3 cloud storage.

    Args:
        resource: Dataset resource link (``s3://bucket/key`` or an HTTP(S) URL).
        root: Dataset destination directory.

    """
    # partition() splits on the first "://" only and yields an empty
    # remainder for malformed resources instead of raising IndexError.
    resource_type, _, obj = resource.partition("://")
    dest_pth = os.path.join(root, resource.split("/")[-1])

    if resource_type == "s3":
        # Unsigned request: public dataset buckets need no credentials.
        s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
        bucket, _, obj_key = obj.partition("/")
        s3_downloader(s3_client, dest_pth, bucket, obj_key)
    else:
        _urlretrieve(resource, dest_pth)


def _load_img(fname):
    """Open *fname* with PIL and return the (lazily loaded) Image object."""
    image = Image.open(fname)
    return image


def _load_npy(fname):
    """Read a ``.npy`` file from disk and return the stored array."""
    array = np.load(fname)
    return array


def _load_img_hdr(fname):
    """Open an ENVI-header image and return its first spectral band."""
    hyperspectral = open_image(fname)
    return hyperspectral.read_band(0)


def _resize_stack(ls):
    """Resize every image in *ls* to the stack's most common size, in place.

    Args:
        ls: List of PIL Images (anything exposing ``.size`` and ``.resize``).

    Returns:
        The same list, with off-size images replaced by resized copies.

    """
    if not ls:
        # zip(*[]) would raise; an empty stack needs no resizing.
        return ls

    sizes = [im.size for im in ls]
    # PIL's .size is (width, height); keep that ordering throughout.
    widths, heights = zip(*sizes)

    # most_common(1) yields the true mode; the previous
    # list(Counter(...))[0] returned the FIRST size encountered
    # (Counter preserves insertion order), not the most frequent one.
    w_mode = collections.Counter(widths).most_common(1)[0][0]
    h_mode = collections.Counter(heights).most_common(1)[0][0]

    for idx, (w, h) in enumerate(sizes):
        if w != w_mode or h != h_mode:
            ls[idx] = ls[idx].resize((w_mode, h_mode))

    return ls


def _load_stack_img(list_path_file):
    """Open every file in *list_path_file*, align sizes, and stack as int16."""
    images = [Image.open(path) for path in list_path_file]
    images = _resize_stack(images)
    return np.stack(images).astype(np.int16)

Functions

def downloader(resource: str, root: str)

Downloader function that handles a general download link or S3 cloud storage.

Args

resource
Dataset resource link.
root
Dataset destination filepath.
Expand source code
def downloader(resource: str, root: str):
    """Downloader function that handle general download link or S3 cloud storage.

    Args:
        resource: Dataset resource link.
        root: Dataset destination filepath.

    """
    resource_type, obj = resource.split("://")[0], resource.split("://")[1]
    dest_pth = os.path.join(root, resource.split("/")[-1])

    if resource_type == "s3":
        s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
        bucket = obj.split("/")[0]
        obj_key = "/".join(obj.split("/")[1:])
        s3_downloader(s3_client, dest_pth, bucket, obj_key)
    else:
        _urlretrieve(resource, dest_pth)
def s3_downloader(s3_client, local_file_name: str, s3_bucket: str, s3_object_key: str)

Download dataset from Amazon S3.

Args

s3_client
Object boto3.client.
local_file_name
Destination filepath.
s3_bucket
S3 bucket.
s3_object_key
S3 object key.
Expand source code
def s3_downloader(s3_client, local_file_name: str, s3_bucket: str, s3_object_key: str):
    """Download dataset from Amazon S3.

    Args:
        s3_client: Object boto3.client.
        local_file_name: Destination filepath.
        s3_bucket: S3 bucket.
        s3_object_key: S3 object key.

    """
    meta_data = s3_client.head_object(Bucket=s3_bucket, Key=s3_object_key)
    total_length = int(meta_data.get("ContentLength", 0))
    downloaded = 0

    def progress(chunk):
        nonlocal downloaded
        downloaded += chunk
        done = int(50 * downloaded / total_length)
        sys.stdout.write("\r[%s%s]" % ("=" * done, " " * (50 - done)))
        sys.stdout.flush()

    print(f"Downloading {s3_object_key}")
    with open(local_file_name, "wb") as f:
        s3_client.download_fileobj(s3_bucket, s3_object_key, f, Callback=progress)