Skip to content

Commit

Permalink
better tensorboard (#7)
Browse files Browse the repository at this point in the history
* better tensorboard

* tweaks

* format

* better logging

* image logging tests

* labeled images test
  • Loading branch information
codekansas authored Feb 15, 2024
1 parent 975b2ad commit 2296feb
Show file tree
Hide file tree
Showing 9 changed files with 490 additions and 72 deletions.
12 changes: 3 additions & 9 deletions examples/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
from dpshdl.impl.mnist import MNIST
from jaxtyping import Array, Float, Int
from PIL import Image

import xax

Expand Down Expand Up @@ -88,13 +86,9 @@ def compute_loss(self, model: Model, batch: Batch, output: Yhatb, state: xax.Sta

def log_valid_step(self, model: Model, batch: Batch, output: Yhatb, state: xax.State) -> None:
max_images = 16
# (x, y), yhat = batch, output.argmax(axis=1)
# labels = [f"pred: {p.item()}, true: {t.item()}" for p, t in zip(yhat[:max_images], y[:max_images])]
# self.log_labeled_images("predictions", (x, labels), max_images=max_images, sep=2)
x, _ = batch
xnp = np.array(jax.device_get(x[:max_images]))
images = [Image.fromarray((xnp[i] * 255).astype(jnp.uint8)) for i in range(max_images)]
self.log_images("images", images, max_images=max_images, sep=2)
(x, y), yhat = batch, output.argmax(axis=1)
labels = [f"pred: {p.item()}\ntrue: {t.item()}" for p, t in zip(yhat[:max_images], y[:max_images])]
self.logger.log_labeled_images("predictions", (x, labels), max_images=max_images)

def get_dataset(self, phase: xax.Phase) -> MNIST:
return MNIST(
Expand Down
104 changes: 104 additions & 0 deletions tests/task/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Runs tests on the logger module."""

import jax.numpy as jnp
import numpy as np
import pytest
from jaxtyping import Array
from PIL import Image
from PIL.Image import Image as PILImage

import xax


class DummyLogger(xax.LoggerImpl):
def __init__(self) -> None:
super().__init__()

self._line: xax.LogLine | None = None

@property
def line(self) -> xax.LogLine:
assert self._line is not None
return self._line

def write(self, line: xax.LogLine) -> None:
self._line = line

def clear(self) -> None:
self._line = None

def should_log(self, state: xax.State) -> bool:
return True


@pytest.mark.parametrize(
"image",
[
np.random.random((32, 32, 3)),
np.random.random((32, 32, 1)),
np.random.random((3, 32, 32)),
np.random.random((32, 32)),
jnp.array(np.random.random((32, 32, 3))),
jnp.array(np.random.random((1, 32, 32))),
Image.new("RGB", (32, 32)),
Image.new("L", (32, 32)),
np.array(Image.new("L", (32, 32))),
],
)
def test_log_image(image: np.ndarray | Array | PILImage) -> None:
with xax.Logger() as logger:
dummy_logger = DummyLogger()
logger.add_logger(dummy_logger)

# Logs the image.
logger.log_image("test", image, target_resolution=(32, 32))
logger.write(xax.State.init_state())
image = dummy_logger.line.images["value"]["test"].image
dummy_logger.clear()
assert image.size == (32, 32)

# Logs the image with a caption.
logger.log_labeled_image("test", (image, "caption\ncaption"), target_resolution=(32, 32))
logger.write(xax.State.init_state())
image = dummy_logger.line.images["value"]["test"].image
dummy_logger.clear()
assert image.size > (32, 32)


@pytest.mark.parametrize(
"images",
[
np.random.random((7, 32, 32, 3)),
np.random.random((7, 32, 32, 1)),
np.random.random((7, 3, 32, 32)),
np.random.random((7, 32, 32)),
jnp.array(np.random.random((7, 32, 32, 3))),
jnp.array(np.random.random((7, 1, 32, 32))),
[Image.new("RGB", (32, 32))] * 7,
[Image.new("L", (32, 32))] * 7,
np.array(Image.new("L", (32, 32)))[None].repeat(7, axis=0),
],
)
def test_log_images(images: np.ndarray | Array | list[PILImage]) -> None:
with xax.Logger() as logger:
dummy_logger = DummyLogger()
logger.add_logger(dummy_logger)

# Logs the images.
logger.log_images("test", images, target_resolution=(32, 32), max_images=6)
logger.write(xax.State.init_state())
image = dummy_logger.line.images["value"]["test"].image
dummy_logger.clear()
assert np.prod(image.size) == 6 * 32 * 32

# Logs the images with captions.
logger.log_labeled_images(
"test",
(images, ["caption\ncaption"] * 7),
target_resolution=(32, 32),
max_images=6,
)
logger.write(xax.State.init_state())
image = dummy_logger.line.images["value"]["test"].image
dummy_logger.clear()
assert np.prod(image.size) > 6 * 32 * 32
3 changes: 3 additions & 0 deletions xax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"LogLine",
"Logger",
"LoggerImpl",
"CallbackLogger",
"JsonLogger",
"StateLogger",
"StdoutLogger",
Expand Down Expand Up @@ -103,6 +104,7 @@
"LogLine": "task.logger",
"Logger": "task.logger",
"LoggerImpl": "task.logger",
"CallbackLogger": "task.loggers.callback",
"JsonLogger": "task.loggers.json",
"StateLogger": "task.loggers.state",
"StdoutLogger": "task.loggers.stdout",
Expand Down Expand Up @@ -177,6 +179,7 @@ def __getattr__(name: str) -> object:
from xax.task.launchers.cli import CliLauncher
from xax.task.launchers.single_process import SingleProcessLauncher
from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
from xax.task.loggers.callback import CallbackLogger
from xax.task.loggers.json import JsonLogger
from xax.task.loggers.state import StateLogger
from xax.task.loggers.stdout import StdoutLogger
Expand Down
Loading

0 comments on commit 2296feb

Please sign in to comment.