
Constrained Parameterizations #13

Merged 3 commits on Feb 2, 2018
70 changes: 36 additions & 34 deletions lucid/misc/gradient_override.py
@@ -31,13 +31,13 @@
or with blocks. Just use a single decorator and everything else will be
handled for you.

If you don't need to serialize your graph and the gradient override isn't
performance critical, you can use the high level `use_gradient()` decorator:

@use_gradient(_foo_grad)
def foo(x): ...

Otherwise, you can use the lower-level `gradient_override_map()`, a
convenience wrapper for `graph.gradient_override_map()`.
"""

@@ -48,22 +48,22 @@ def foo(x): ...

def register_to_random_name(grad_f):
"""Register a gradient function to a random string.

In order to use a custom gradient in TensorFlow, it must be registered to a
string. This is both a hassle, and -- because only one function can every be
registered to a string -- annoying to iterate on in an interactive
environment.

This function registers a function to a unique random string of the form:

{FUNCTION_NAME}_{RANDOM_SALT}

And then returns the random string. This is a helper in creating more
convenient gradient overrides.

Args:
grad_f: gradient function to register. Should map (op, grad) -> grad(s)

Returns:
String that gradient function was registered to.
"""
@@ -75,26 +75,26 @@ def register_to_random_name(grad_f):
@contextmanager
def gradient_override_map(override_dict):
"""Convenience wrapper for graph.gradient_override_map().

This function provides two conveniences over normal TensorFlow gradient
overrides: it automatically uses the default graph instead of you needing to
find the graph, and it automatically registers gradient functions for you when
you pass a function instead of a string.

Example:

def _foo_grad_alt(op, grad): ...

with gradient_override_map({"Foo": _foo_grad_alt}): ...

Args:
override_dict: A dictionary describing how to override the gradient.
keys: strings corresponding to the op type that should have their gradient
overridden.
values: functions or strings registered to gradient functions

"""
override_dict_by_name = {}
- for op_name, grad_f in override_dict.iteritems():
+ for (op_name, grad_f) in override_dict.items():
if isinstance(grad_f, str):
override_dict_by_name[op_name] = grad_f
else:
@@ -105,7 +105,7 @@ def _foo_grad_alt(op, grad): ...

def use_gradient(grad_f):
"""Decorator for easily setting custom gradients for TensorFlow functions.

* DO NOT use this function if you need to serialize your graph.
* This function will cause the decorated function to run slower.

@@ -127,49 +127,51 @@ def foo(x1, x2, x3): ...

def function_wrapper(f):
def inner(*inputs):

# TensorFlow only supports (as of writing) overriding the gradient of
# individual ops. In order to override the gradient of `f`, we need to
# somehow make it appear to be an individual TensorFlow op.
#
# Our solution is to create a PyFunc that mimics `f`.
#
# In particular, we construct a graph for `f` and run it, then use a
# stateful PyFunc to stash its results in Python. Then we have another
# PyFunc mimic it by taking all the same inputs and returning the stashed
# output.
#
# I wish we could do this without PyFunc, but I don't see a way to have
# it be fully general.

state = {"out_value": None}

# First, we need to run `f` and store its output.

out = f(*inputs)

def store_out(out_value):
"""Store the value of out to a python variable."""
state["out_value"] = out_value
- store = tf.py_func(store_out, [out], (), stateful=True,
- name = "store_" + f.__name__ )
+ store_name = "store_" + f.__name__
+ store = tf.py_func(store_out, [out], (), stateful=True, name=store_name)

# Next, we create the mock function, with an overridden gradient.
# Note that we need to make sure store gets evaluated before the mock
# runs.

def mock_f(*inputs):
"""Mimic f by retrieving the stored value of out."""
return state["out_value"]

with tf.control_dependencies([store]):
- with gradient_override({"PyFunc": grad_f_name}):
+ with gradient_override_map({"PyFunc": grad_f_name}):
+ mock_name = "mock_" + f.__name__
mock_out = tf.py_func(mock_f, inputs, out.dtype, stateful=True,
- name = "mock" + f.__name__ )
+ name=mock_name)
mock_out.set_shape(out.get_shape())

# Finally, we can return the mock.

return mock_out
return inner
return function_wrapper
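
For context, a minimal usage sketch of the decorator (illustrative, not part of this diff; `_passthrough_grad` and `clipped_relu` are made-up names):

    import tensorflow as tf
    from lucid.misc.gradient_override import use_gradient

    def _passthrough_grad(op, grad):
      # Pretend the wrapped function is the identity: hand the incoming
      # gradient straight through to the input.
      return grad

    @use_gradient(_passthrough_grad)
    def clipped_relu(x):
      # Forward pass clips activations to [0, 6]; the backward pass uses
      # _passthrough_grad instead of the true clip/relu gradient.
      return tf.clip_by_value(tf.nn.relu(x), 0.0, 6.0)

    x = tf.placeholder(tf.float32, [8])
    y = clipped_relu(x)
    dy_dx = tf.gradients(tf.reduce_sum(y), x)[0]  # all ones, regardless of x

Because the decorator routes the forward pass through a stateful PyFunc, graphs built this way cannot be serialized, which is exactly the caveat in the docstring above.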
4 changes: 3 additions & 1 deletion lucid/optvis/param/color.py
@@ -19,6 +19,8 @@
import numpy as np
import tensorflow as tf

from lucid.optvis.param.unit_balls import constrain_L_inf

color_correlation_svd_sqrt = np.asarray([[0.26, 0.09, 0.02],
[0.27, 0.00, -0.05],
[0.27, -0.09, 0.03]]).astype("float32")
@@ -70,4 +72,4 @@ def to_valid_rgb(t, decorrelate=False, sigmoid=True):
if sigmoid:
return tf.nn.sigmoid(t)
else:
- return tf.clip_by_value(t, 0, 1)
+ return constrain_L_inf(2*t-1)/2 + 0.5
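
The forward pass of the new return value matches the old clip exactly; only the gradient behavior changes. A quick NumPy sketch of the forward algebra (illustrative, not part of the diff):

    import numpy as np

    t = np.array([-0.2, 0.4, 1.3], dtype=np.float32)
    x = 2 * t - 1                        # map the nominal [0, 1] range to [-1, 1]: [-1.4, -0.2, 1.6]
    x = x / np.maximum(1.0, np.abs(x))   # constrain_L_inf forward pass:            [-1.0, -0.2, 1.0]
    print(x / 2 + 0.5)                   # back to [0, 1]: [0.0, 0.4, 1.0], same as np.clip(t, 0, 1)

The difference is in the backward pass: tf.clip_by_value has zero gradient for every out-of-range component, while constrain_L_inf (defined in the new unit_balls.py below) passes the gradient through unless a component is already at the edge and inp * grad > 0, the case its gradient function calls "outward".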
119 changes: 119 additions & 0 deletions lucid/optvis/param/unit_balls.py
@@ -0,0 +1,119 @@
"""Optimize within unit balls in TensorFlow.

In adversarial examples, one often wants to optimize within a constrained ball.
This module makes this easy through functions like unit_ball_L2(), which
creates a TensorFlow variable constrained within an L2 unit ball.

EXPERIMENTAL: Do not use for adversarial examples if you need to be confident
they are strong attacks. We are not yet confident in this code.
"""

import tensorflow as tf

from lucid.misc.gradient_override import use_gradient


def dot(a, b):
return tf.reduce_sum(a * b)


def _constrain_L2_grad(op, grad):
"""Gradient for constrained optimization on an L2 unit ball.

This function projects the gradient onto the ball if you are on the boundary
(or outside!), but leaves it untouched if you are inside the ball.

Args:
op: the tensorflow op we're computing the gradient for.
grad: gradient we need to backprop

Returns:
(projected if necessary) gradient.
"""
inp = op.inputs[0]
inp_norm = tf.norm(inp)
unit_inp = inp / inp_norm

grad_projection = dot(unit_inp, grad)
parallel_grad = unit_inp * grad_projection

is_in_ball = tf.less_equal(inp_norm, 1)
is_pointed_inward = tf.less(grad_projection, 0)
allow_grad = tf.logical_or(is_in_ball, is_pointed_inward)
clip_grad = tf.logical_not(allow_grad)

clipped_grad = tf.cond(clip_grad, lambda: grad - parallel_grad, lambda: grad)

return clipped_grad
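
# Worked example (editorial illustration, not part of the diff): take inp = [2., 0.]
# (norm 2, outside the ball), so unit_inp = [1., 0.]. For grad = [1., 1.],
# grad_projection = 1 >= 0, allow_grad is False, and the radial component is
# removed: grad - 1 * unit_inp = [0., 1.]. For grad = [-1., 1.] the projection is
# negative ("pointed inward"), and for any inp with norm <= 1, the gradient is
# passed through unchanged.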


@use_gradient(_constrain_L2_grad)
def constrain_L2(x):
return x / tf.maximum(1.0, tf.norm(x))


def unit_ball_L2(shape):
"""A tensorflow variable tranfomed to be constrained in a L2 unit ball.

EXPERIMENTAL: Do not use for adversarial examples if you need to be confident
they are strong attacks. We are not yet confident in this code.
"""
x = tf.Variable(tf.zeros(shape))
return constrain_L2(x)


def _constrain_L_inf_grad(precondition=True):

def grad_f(op, grad):
"""Gradient for constrained preconditioned optimization on an L_inf unit
ball.

This function projects the gradient onto the ball if you are on the
boundary (or outside!). If `precondition` is set, it also replaces the gradient
with its sign, the direction of steepest descent under the L_inf norm.

Args:
op: the tensorflow op we're computing the gradient for.
grad: gradient we need to backprop

Returns:
(projected if necessary) preconditioned gradient.
"""
inp = op.inputs[0]
dim_at_edge = tf.greater_equal(tf.abs(inp), 1.0)
dim_outward = tf.greater(inp * grad, 0.0)
if precondition:
grad = tf.sign(grad)

return tf.where(
tf.logical_and(dim_at_edge, dim_outward),
tf.zeros(grad.shape),
grad
)
return grad_f
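
# Worked example (editorial illustration, not part of the diff): with
# inp = [1.0, 0.3, -1.2] and grad = [0.5, -2.0, -0.1], inp * grad = [0.5, -0.6, 0.12],
# so dim_at_edge = [True, False, True] and dim_outward = [True, False, True]; the
# first and last components are zeroed. With precondition=True the surviving
# component is first replaced by its sign, giving [0., -1., 0.]; with
# precondition=False the result is [0., -2.0, 0.].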


@use_gradient(_constrain_L_inf_grad(precondition=True))
def constrain_L_inf_precondition(x):
return x / tf.maximum(1.0, tf.abs(x))


@use_gradient(_constrain_L_inf_grad(precondition=False))
def constrain_L_inf(x):
return x / tf.maximum(1.0, tf.abs(x))


def unit_ball_L_inf(shape, precondition=True):
"""A tensorflow variable tranfomed to be constrained in a L_inf unit ball.

Note that, with precondition=True (the default), this code also preconditions
the gradient to follow the direction of steepest descent under the L_inf norm.

EXPERIMENTAL: Do not use for adversarial examples if you need to be confident
they are strong attacks. We are not yet confident in this code.
"""
x = tf.Variable(tf.zeros(shape))
if precondition:
return constrain_L_inf_precondition(x)
else:
return constrain_L_inf(x)
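
For context, a minimal sketch of the adversarial-example use the module docstring mentions (illustrative only: `model_logits` is a hypothetical stand-in for the network being attacked, label 0 stands for the true class, and the module's EXPERIMENTAL caveat applies):

    import tensorflow as tf
    from lucid.optvis.param.unit_balls import unit_ball_L_inf

    eps = 8.0 / 255.0
    image = tf.placeholder(tf.float32, [1, 224, 224, 3])
    delta = eps * unit_ball_L_inf([1, 224, 224, 3])   # perturbation kept inside [-eps, eps]
    adv_image = tf.clip_by_value(image + delta, 0.0, 1.0)

    logits = model_logits(adv_image)                  # hypothetical model function
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.constant([0]), logits=logits)
    # Maximize the loss on the assumed true class; with precondition=True every update
    # is a sign-of-gradient step on the underlying variable, i.e. FGSM-style iteration.
    # (In practice you would pass var_list so only the perturbation variable is trained.)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(-tf.reduce_mean(xent))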
46 changes: 46 additions & 0 deletions tests/test_constrained_optimization.py
@@ -0,0 +1,46 @@
from __future__ import absolute_import, division, print_function

import pytest

import tensorflow as tf
from lucid.optvis.param.unit_balls import unit_ball_L2, unit_ball_L_inf


learning_rate = .1
num_steps = 16


@pytest.mark.parametrize("shape", [(3), (2, 5, 5)])
def test_unit_ball_L2(shape, eps=1e-6):
"""Tests that a L2 unit ball variable's norm stays roughly within 1.0.
Note: only holds down to eps ~= 5e-7.
"""
with tf.Session() as sess:
unit_ball = unit_ball_L2(shape)
unit_ball_L2_norm = tf.norm(unit_ball)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
objective = optimizer.minimize(-unit_ball)
tf.global_variables_initializer().run()
norm_value = unit_ball_L2_norm.eval()
for i in range(num_steps):
_, new_norm_value = sess.run([objective, unit_ball_L2_norm])
assert new_norm_value >= norm_value - eps
assert new_norm_value <= 1.0 + eps
norm_value = new_norm_value


@pytest.mark.parametrize("shape", [(3), (2, 5, 5)])
@pytest.mark.parametrize("precondition", [True, False])
def test_unit_ball_L_inf(shape, precondition, eps=1e-6):
"""Tests that a L infinity unit ball variables' stay roughly within 1.0.
Note: only holds down to eps ~= 5e-7.
"""
with tf.Session() as sess:
unit_ball = unit_ball_L_inf(shape, precondition=precondition)
unit_ball_max = tf.reduce_max(unit_ball)
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
objective = optimizer.minimize(-unit_ball)
tf.global_variables_initializer().run()
for i in range(num_steps):
_, unit_ball_max_value = sess.run([objective, unit_ball_max])
assert unit_ball_max_value <= 1.0 + eps
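
Assuming the repository's usual pytest setup (not shown in this diff), the new tests can be run locally with:

    python -m pytest tests/test_constrained_optimization.py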
3 changes: 2 additions & 1 deletion tests/test_objectives.py
@@ -9,9 +9,10 @@
model = InceptionV1()
model.load_graphdef()


def test_class_logit():
obj = objectives.neuron("mixed4c_pre_relu", 0)
- rendering = render.render_vis(model, obj, thresholds=(1, 32), verbose=False)
+ rendering = render.render_vis(model, obj, thresholds=(1, 4), verbose=False)
start_image = rendering[0]
end_image = rendering[-1]
assert (start_image != end_image).any()