Skip to content

Commit

Permalink
Move io to lucid.misc.io, enable show in render, fix text mode readin…
Browse files Browse the repository at this point in the history
…g by avoiding it, more structure for tests
  • Loading branch information
Ludwig Schubert committed Feb 6, 2018
1 parent 42903b6 commit 148449a
Show file tree
Hide file tree
Showing 32 changed files with 620 additions and 356 deletions.
32 changes: 26 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
# Lucid
*DeepDream, but sane. Home of cats, dreams, and interpretable neural networks.*

Lucid is a collection of infrastructure and tools for research in neural network interpretability.
Lucid is a collection of infrastructure and tools for research in neural
network interpretability.

In particular, it provides state of the art implementations of [feature visualization techniques](https://distill.pub/2017/feature-visualization/), and flexible abstractions that make it very easy to explore new research directions.
In particular, it provides state of the art implementations of [feature
visualization techniques](https://distill.pub/2017/feature-visualization/),
and flexible abstractions that make it very easy to explore new research
directions.


# Dive In with Colab Notebooks

Start visualizing neural networks ***with no setup***. The following notebooks run in your browser.
Start visualizing neural networks ***with no setup***. The following notebooks
run in your browser.

**Beginner notebooks**:

* [lucid tutorial]() (TODO) -- Introduction to the core ideas of lucid.
* [lucid tutorial]() (TODO) -- Introduction to the core ideas of lucid.
* [DeepDream]() (TODO) -- Make some dog slugs and crazy art.

**More advanced**:
Expand All @@ -36,11 +41,26 @@ How lucid is structured:
* [**recipes**]():
Less general code that makes a particular visualization.

Note that we do a lot of our research in colab notebooks and transition code here as it matures.
Note that we do a lot of our research in colab notebooks and transition code
here as it matures.


# License and Disclaimer

You may use this software under the Apache 2.0 License. See [LICENSE](LICENSE).

This project is research code. It is not an official Google product.
This project is research code. It is not an official Google product.


# Running tests

Use `tox` to run the test suite on all supported environments.

To run tests only for a specific module, pass a folder to `tox`:
`tox tests/misc/io`

To run tests only in a specific environment, pass the environment's identifier
via the `-e` flag: `tox -e py27`.

After adding dependencies to `setup.py`, run tox with the `--recreate` flag to
update the environments' dependencies.
33 changes: 33 additions & 0 deletions lucid/misc/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2018 The Deepviz Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Ensuring compatibility across environments, e.g. Jupyter/Colab/Shell."""

from __future__ import absolute_import, division, print_function


def is_notebook_environment():
    """Report whether the code is running inside a notebook front end.

    The decision is based on the class name of the object returned by
    ``get_ipython()``.

    Returns:
      True for 'ZMQInteractiveShell' (IPython Notebook) and 'Shell'
      (Colaboratory Notebook). False for every other shell class, including
      'TerminalInteractiveShell' (terminal running IPython), and False when
      the name ``get_ipython`` is not defined at all.
    """
    notebook_shells = (
        'ZMQInteractiveShell',  # IPython Notebook
        'Shell',  # Colaboratory Notebook
    )
    try:
        # The bare name lookup raises NameError when get_ipython is not
        # available in this interpreter.
        shell = get_ipython().__class__.__name__
    except NameError:
        return False  # Probably standard Python interpreter
    # Anything else (e.g. 'TerminalInteractiveShell' or an unknown shell type)
    # is not treated as a notebook.
    return shell in notebook_shells
3 changes: 3 additions & 0 deletions lucid/misc/io/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
These modules' names end in 'ing' so the three main public methods can stay as
simple verbs ('load', 'save', and 'show') without creating namespace conflicts
if you do need to import the module or some of the lower level methods.
3 changes: 3 additions & 0 deletions lucid/misc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from lucid.misc.io.showing import show
from lucid.misc.io.loading import load
from lucid.misc.io.saving import save
29 changes: 19 additions & 10 deletions lucid/util/load.py → lucid/misc/io/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,51 @@
import numpy as np
import PIL.Image

from lucid.util.read import reading
from lucid.misc.io.reading import read_handle


# create logger with module name, e.g. lucid.util.read
# create logger with module name, e.g. lucid.misc.io.reading
log = logging.getLogger(__name__)


def _load_npy(handle):
def _load_npy(handle, **kwargs):
"""Load npy file as numpy array."""
del kwargs
return np.load(handle)


def _load_img(handle):
def _load_img(handle, **kwargs):
"""Load image file as numpy array."""
del kwargs
# PIL.Image will infer image type from provided handle's file extension
pil_img = PIL.Image.open(handle)
# using np.divide should avoid an extra copy compared to doing division first
return np.divide(pil_img, 255, dtype=np.float64)


def _load_json(handle):
def _load_json(handle, encoding=None):
"""Load json file as python object."""
del encoding
return json.load(handle)


def _load_text(handle, encoding='utf-8'):
"""Load and decode a string."""
return handle.read().decode(encoding)

loaders = {
".png": _load_img,
".jpg": _load_img,
".jpeg": _load_img,
".npy": _load_npy,
".npz": _load_npy,
".json": _load_json,
".txt": _load_text,
".md": _load_text,
}


def load(url, cache=None):
def load(url, cache=None, encoding='utf-8'):
"""Load a file.
File format is inferred from url. File retrieval strategy is inferred from
Expand All @@ -84,14 +93,14 @@ def load(url, cache=None):
if ext in loaders:
loader = loaders[ext]
message = "Using inferred loader '%s' due to passed file extension '%s'."
log.info(message, loader.__name__[6:], ext)
with reading(url, cache=cache) as handle:
result = loader(handle)
log.debug(message, loader.__name__[6:], ext)
with read_handle(url, cache=cache) as handle:
result = loader(handle, encoding=encoding)
return result
else:
log.warn("Unknown extension '%s', attempting to load as image.", ext)
try:
with reading(url, cache=cache) as handle:
with read_handle(url, cache=cache) as handle:
result = _load_img(handle)
except Exception as e:
message = "Could not load resource %s as image. Supported extensions: %s"
Expand Down
64 changes: 30 additions & 34 deletions lucid/util/read.py → lucid/misc/io/reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

"""Methods for reading bytes from arbitrary sources.
"""Methods for reading bytes from arbitrary sources.
This module takes a URL, infers how to locate it,
loads the data into memory and returns it.
Expand All @@ -33,17 +33,17 @@
from tempfile import gettempdir
from io import BytesIO, StringIO

from lucid.util.write import write
from lucid.misc.io.writing import write


# create logger with module name, e.g. lucid.util.read
# create logger with module name, e.g. lucid.misc.io.reading
log = logging.getLogger(__name__)


# Public functions


def read(url, mode='rb', cache=None):
def read(url, encoding=None, cache=None):
"""Read from any URL.
Internally differentiates between URLs supported by tf.gfile, such as URLs
Expand All @@ -52,21 +52,31 @@ def read(url, mode='rb', cache=None):
Args:
url: a URL including scheme or a local path
mode: mode in which to open the file. defaults to binary ('rb')
encoding: if specified, encoding used to decode the read bytes into a string.
Defaults to None, in which case the raw bytes are returned undecoded.
cache: whether to attempt caching the resource. Defaults to True only if
the given URL specifies a remote resource.
Returns:
All bytes form the specified resource if it could be reached.
All bytes from the specified resource, or a decoded string of those.
"""
with reading(url, mode, cache) as handle:
return handle.read()
with read_handle(url, cache) as handle:
data = handle.read()

if encoding:
data = data.decode(encoding)

return data


@contextmanager
def reading(url, mode=None, cache=None):
def read_handle(url, cache=None):
"""Read from any URL with a file handle.
Use this to get a handle to a file rather than eagerly load the data:
```
with reading(url) as handle:
with read_handle(url) as handle:
result = something.load(handle)
result.do_something()
Expand All @@ -84,26 +94,19 @@ def reading(url, mode=None, cache=None):
"""
scheme = urlparse(url).scheme

if mode is None:
if _supports_binary_mode(scheme):
mode = 'rb'
else:
mode = 'r'
log.debug("Mode not specified, using '%s'", mode)

if _is_remote(scheme) and cache is None:
cache = True
log.debug("Cache not specified, enabling because resource is remote.")

if cache:
handle = _read_and_cache(url, mode)
handle = _read_and_cache(url)
else:
if scheme in ('http', 'https'):
handle = _handle_web_url(url, mode)
handle = _handle_web_url(url)
elif scheme == 'gs':
handle = _handle_gcs_url(url, mode)
handle = _handle_gcs_url(url)
else:
handle = _handle_gfile(url, mode)
handle = _handle_gfile(url)

yield handle
handle.close()
Expand All @@ -112,48 +115,41 @@ def reading(url, mode=None, cache=None):
# Handlers


def _handle_gfile(url, mode):
def _handle_gfile(url, mode='rb'):
return gfile.Open(url, mode)


def _handle_web_url(url, mode):
def _handle_web_url(url):
del mode # unused
return urlopen(url)


def _handle_gcs_url(url, mode):
def _handle_gcs_url(url):
# TODO: transparently allow authenticated access through storage API
_, resource_name = url.split('://')
base_url = 'https://storage.googleapis.com/'
url = urljoin(base_url, resource_name)
return _handle_web_url(url, mode)
return _handle_web_url(url)


# Helper Functions


def _supports_binary_mode(scheme):
return True


def _is_remote(scheme):
return scheme in ('http', 'https', 'gs')


RESERVED_PATH_CHARS = re.compile("[^a-zA-Z0-9]")


def _read_and_cache(url, mode):
def _read_and_cache(url):
local_name = RESERVED_PATH_CHARS.sub('_', url)
local_path = os.path.join(gettempdir(), local_name)
if os.path.exists(local_path):
log.info("Found cached file '%s'.", local_path)
return _handle_gfile(local_path, mode)
return _handle_gfile(local_path)
else:
log.info("Caching URL '%s' locally at '%s'.", url, local_path)
data = read(url, cache=False) # important to avoid endless loop
write(data, local_path)
if 'b' in mode:
return BytesIO(data)
else:
return StringIO(data)
return BytesIO(data)
19 changes: 10 additions & 9 deletions lucid/util/save.py → lucid/misc/io/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@
import os.path
import json
import numpy as np
import PIL.Image

from lucid.util.write import write, writing
from lucid.util.array_to_image import _serialize_array, _normalize_array
from lucid.misc.io.writing import write, write_handle
from lucid.misc.io.serialize_array import _normalize_array


# create logger with module name, e.g. lucid.util.save
# create logger with module name, e.g. lucid.misc.io.saving
log = logging.getLogger(__name__)


Expand All @@ -47,7 +48,7 @@ def save_json(object, url, indent=2):

def save_npy(object, url):
"""Save numpy array as npy file."""
with writing(url, "w") as handle:
with write_handle(url, "w") as handle:
np.save(handle, object)


Expand All @@ -65,12 +66,12 @@ def save_npz(object, url):
def save_img(object, url, **kwargs):
    """Save a numpy array or a PIL image as an image file.

    Args:
      object: either a numpy array, which is passed through
        `_normalize_array` and converted with `PIL.Image.fromarray`, or an
        existing `PIL.Image.Image`, which is saved as is.
      url: destination, opened via `write_handle`.
      **kwargs: forwarded unchanged to the image's `save` method.

    Raises:
      ValueError: if `object` is neither a numpy array nor a PIL image.
    """
    if isinstance(object, np.ndarray):
        normalized = _normalize_array(object)
        image = PIL.Image.fromarray(normalized)
    elif isinstance(object, PIL.Image.Image):
        # `PIL.Image` is a module; the image class is `PIL.Image.Image`.
        # An existing image needs no conversion before saving.
        image = object
    else:
        raise ValueError("Can only save_img for numpy arrays or PIL.Images!")

    with write_handle(url) as handle:
        image.save(handle, **kwargs)  # will infer format from handle's url ext.


Expand Down Expand Up @@ -103,7 +104,7 @@ def save(thing, url, **kwargs):

if ext in savers:
saver = savers[ext]
data = saver(thing, url, **kwargs)
saver(thing, url, **kwargs)
else:
message = "Unknown extension '{}', supports {}."
raise RuntimeError(message.format(ext, loaders))
Loading

0 comments on commit 148449a

Please sign in to comment.