Skip to content

Commit

Permalink
Merge pull request tensorflow#198 from tensorflow/batch-load-save
Browse files Browse the repository at this point in the history
Adding batch load and save to misc.io
  • Loading branch information
gabgoh authored Oct 9, 2019
2 parents 294abf3 + 8a4b1c2 commit 668679e
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 27 deletions.
2 changes: 1 addition & 1 deletion lucid/misc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from lucid.misc.io.showing import show
from lucid.misc.io.loading import load
from lucid.misc.io.saving import save, CaptureSaveContext
from lucid.misc.io.saving import save, CaptureSaveContext, batch_save
from lucid.misc.io.scoping import io_scope, scope_url
15 changes: 10 additions & 5 deletions lucid/misc/io/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
This should support for example PNG images, JSON files, npy files, etc.
"""

from __future__ import absolute_import, division, print_function

import os
import json
import logging
Expand All @@ -33,7 +31,7 @@
from google.protobuf.message import DecodeError

from lucid.misc.io.reading import read_handle
from lucid.misc.io import scoping
from lucid.misc.io.scoping import current_io_scopes, set_io_scopes

# from lucid import modelzoo

Expand All @@ -44,9 +42,15 @@

def _load_urls(urls, cache=None, **kwargs):
pages = {}
caller_io_scopes = current_io_scopes()

def _do_load(url):
set_io_scopes(caller_io_scopes)
return load(url, cache=cache, **kwargs)

with ThreadPoolExecutor(max_workers=8) as executor:
future_to_urls = {
executor.submit(load, url, cache=cache, **kwargs): url for url in urls
executor.submit(_do_load, url): url for url in urls
}
for future in as_completed(future_to_urls):
url = future_to_urls[future]
Expand Down Expand Up @@ -195,7 +199,7 @@ def load_using_loader(url_or_handle, loader, cache, **kwargs):
log.warning(
"While loading '%s' an error occurred. Purging cache once and trying again; if this fails we will raise an Exception! Current io scopes: %r",
url,
scoping.current_io_scopes(),
current_io_scopes(),
)
# since this may have been cached, it's our responsibility to try again once
# since we use a handle here, the next DecodeError should propagate upwards
Expand All @@ -216,3 +220,4 @@ def get_extension(url_or_handle):
if not ext:
raise RuntimeError("No extension in URL: " + url_or_handle)
return ext

27 changes: 20 additions & 7 deletions lucid/misc/io/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,14 @@
Possible extension: if not given a URL this could create one and return it?
"""

from __future__ import absolute_import, division, print_function

import logging
import subprocess
import warnings
import threading
from copy import copy

# from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
import os.path
import json
from typing import Optional, List

from typing import Optional, List, Tuple
import numpy as np
import PIL.Image

Expand Down Expand Up @@ -273,3 +268,21 @@ def save(thing, url_or_handle, save_context: Optional[CaptureSaveContext] = None
result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])

return result


def batch_save(save_ops: List[Tuple], num_workers: int = 16):
    """Perform many `save` operations concurrently on a thread pool.

    Args:
      save_ops: list of tuples, each either `(thing, url)` or
        `(thing, url, kwargs_dict)`, describing one call to `save`.
      num_workers: number of worker threads used for the saves.

    Returns:
      A list with the result of each `save`, in the same order as `save_ops`.

    Raises:
      ValueError: if a tuple in `save_ops` has fewer than 2 or more than 3
        entries (raised from the worker, surfaced via `future.result()`).
    """
    # Capture the caller's io scopes and save context up front: worker threads
    # have their own thread-local state, so each save must be re-scoped to
    # behave as if it ran on the calling thread.
    caller_io_scopes = current_io_scopes()
    current_save_context = CaptureSaveContext.current_save_context()

    def _do_save(save_op_tuple: Tuple):
        # Runs on a worker thread; propagate the caller's scopes first.
        set_io_scopes(caller_io_scopes)
        if not 2 <= len(save_op_tuple) <= 3:
            raise ValueError(f'unknown save tuple size: {len(save_op_tuple)}')
        # Unpack once instead of branching on tuple size: the optional third
        # element is a kwargs dict forwarded to `save`.
        thing, url, *rest = save_op_tuple
        kwargs = rest[0] if rest else {}
        return save(thing, url, save_context=current_save_context, **kwargs)

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        save_op_futures = [executor.submit(_do_save, op) for op in save_ops]
        # Collect in submission order so results line up with save_ops.
        return [future.result() for future in save_op_futures]
28 changes: 21 additions & 7 deletions tests/misc/io/test_loading.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
# -*- coding: UTF-8 -*-
from __future__ import absolute_import, division, print_function

import os

import pytest

import numpy as np
from lucid.misc.io.loading import load
from lucid.misc.io.scoping import io_scope
import io

# Fixture images shared by the single-image and batch loading tests.
# Covers the extensions the loader dispatches on: PNG with lower- and
# upper-case extension, JPEG as both .jpg and .jpeg, and an unknown
# ".xyz" extension (presumably handled by content-based fallback in
# loading.py — TODO confirm).
test_images = [
    "./tests/fixtures/rgbeye.png",
    "./tests/fixtures/noise_uppercase.PNG",
    "./tests/fixtures/rgbeye.jpg",
    "./tests/fixtures/noise.jpeg",
    "./tests/fixtures/image.xyz",
]

def test_load_json():
path = "./tests/fixtures/dictionary.json"
Expand Down Expand Up @@ -39,13 +49,7 @@ def test_load_npz():
assert isinstance(arrays, np.lib.npyio.NpzFile)


@pytest.mark.parametrize("path", [
"./tests/fixtures/rgbeye.png",
"./tests/fixtures/noise_uppercase.PNG",
"./tests/fixtures/rgbeye.jpg",
"./tests/fixtures/noise.jpeg",
"./tests/fixtures/image.xyz",
])
@pytest.mark.parametrize("path", test_images)
def test_load_image(path):
image = load(path)
assert image.shape is not None
Expand Down Expand Up @@ -79,3 +83,13 @@ def test_load_protobuf():
path = "./tests/fixtures/graphdef.pb"
graphdef = load(path)
assert "int_val: 42" in repr(graphdef)


def test_batch_load():
    """Loading a list of names inside an io_scope should return one result
    per input, in input order, identical to loading each file individually."""
    image_names = [os.path.basename(image) for image in test_images]
    with io_scope('./tests/fixtures'):
        images = load(image_names)
    assert len(images) == len(test_images)
    # zip instead of an index loop: pairs each original path with the
    # batch-loaded result at the same position.
    for original_path, batch_loaded in zip(test_images, images):
        assert np.allclose(load(original_path), batch_loaded)

28 changes: 21 additions & 7 deletions tests/misc/io/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from __future__ import absolute_import, division, print_function

import time

import pytest

import numpy as np
from lucid.misc.io.saving import save, CaptureSaveContext
from lucid.misc.io.saving import save, CaptureSaveContext, batch_save
from lucid.misc.io.scoping import io_scope, current_io_scopes
from concurrent.futures import ThreadPoolExecutor
import os.path
Expand Down Expand Up @@ -135,8 +131,10 @@ def test_capturing_saves():
path = "./tests/fixtures/generated_outputs/test_capturing_saves.txt"
_remove(path)
context = CaptureSaveContext()
with context:
save("test", path)

with context, io_scope("./tests/fixtures/generated_outputs"):
save("test", "test_capturing_saves.txt")

captured = context.captured_saves
assert len(captured) == 1
assert "type" in captured[0]
Expand All @@ -156,3 +154,19 @@ def _return_io_scope(io_scope_path):
futures = {executor.submit(_return_io_scope, f'gs://test-{i}'): f'gs://test-{i}' for i in range(n_tasks)}
results = [f.result() for f in futures]
assert results == list(futures.values())


def test_batch_saves():
    """batch_save should write every tuple and record each save in the
    active CaptureSaveContext."""
    save_ops = [(str(i), f"write_batch_{i}.txt") for i in range(5)]
    # Plain loop for the side effect — a list comprehension here built a
    # throwaway list of Nones.
    for i in range(5):
        _remove(f"./tests/fixtures/generated_outputs/write_batch_{i}.txt")

    context = CaptureSaveContext()
    with context, io_scope("./tests/fixtures/generated_outputs"):
        results = batch_save(save_ops)
        assert len(results) == 5

    assert len(context.captured_saves) == 5
    assert context.captured_saves[0]['type'] == 'txt'
    # (leftover debug print of captured_saves removed)
    assert 'write_batch_' in context.captured_saves[0]['url']
    # Generator form: all() short-circuits without materializing a list.
    assert all(
        os.path.isfile(f"./tests/fixtures/generated_outputs/write_batch_{i}.txt")
        for i in range(5)
    )

0 comments on commit 668679e

Please sign in to comment.