Skip to content

Commit

Permalink
Merge pull request FederatedAI#4225 from FederatedAI/bugfix/bump-tens…
Browse files Browse the repository at this point in the history
…orflow-version

refactor: upgrade tensorflow to 2.x
  • Loading branch information
dylan-fan authored Aug 18, 2022
2 parents 985dbe4 + c525d13 commit 4c4e112
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 210 deletions.
176 changes: 48 additions & 128 deletions python/federatedml/nn/backend/tf_keras/nn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,13 @@
#
import copy
import io
import json
import os
import tempfile
import zipfile
import uuid
import zipfile

import numpy as np
import tensorflow as tf
from tensorflow.keras.backend import gradients
from tensorflow.keras.callbacks import History
from tensorflow.python.keras import backend

try:
from tensorflow import (
get_default_graph,
global_variables,
initialize_variables,
report_uninitialized_variables,
assign,
placeholder,
)
from tensorflow.keras.backend import set_session
except ImportError:
from tensorflow.compat.v1 import (
get_default_graph,
report_uninitialized_variables,
global_variables,
initialize_variables,
assign,
placeholder,
)
from tensorflow.compat.v1.keras.backend import set_session

tf.compat.v1.disable_eager_execution()

from federatedml.framework.weights import OrderDictWeights, Weights
from federatedml.nn.backend.tf_keras import losses
from federatedml.nn.homo_nn.nn_model import DataConverter, NNModel
Expand All @@ -70,17 +43,7 @@ def _zip_dir_as_bytes(path):
return zip_bytes


def _init_session():
sess = backend.get_session()
get_default_graph()
set_session(sess)
return sess


def _modify_model_input_shape(nn_struct, input_shape):
import copy
import json

if not input_shape:
return json.dumps(nn_struct)

Expand All @@ -99,15 +62,15 @@ def _modify_model_input_shape(nn_struct, input_shape):

if struct["config"]["layers"][0].get("config"):
struct["config"]["layers"][0]["config"]["batch_input_shape"] = [
None
] + input_shape
None,
*input_shape,
]
return json.dumps(struct)
else:
return json.dump(struct)


def build_keras(nn_define, loss, optimizer, metrics, **kwargs):
_init_session()
nn_define_json = _modify_model_input_shape(
nn_define, kwargs.get("input_shape", None)
)
Expand All @@ -119,46 +82,32 @@ def build_keras(nn_define, loss, optimizer, metrics, **kwargs):

class KerasNNModel(NNModel):
def __init__(self, model):
self._sess: tf.Session = _init_session()
self._model: tf.keras.Sequential = model
self._trainable_weights = {
v.name: v for v in self._model.trainable_weights
}
self._trainable_weights = {v.name: v for v in self._model.trainable_weights}

self._initialize_variables()
self._loss = None
self._loss_fn = None

def compile(self, loss, optimizer, metrics):
optimizer_instance = getattr(tf.keras.optimizers, optimizer.optimizer)(
**optimizer.kwargs
)
loss_fn = getattr(losses, loss)
self._model.compile(optimizer=optimizer_instance, loss=loss_fn, metrics=metrics)

def _initialize_variables(self):
uninitialized_var_names = [
bytes.decode(var)
for var in self._sess.run(report_uninitialized_variables())
]
uninitialized_vars = [
var
for var in global_variables()
if var.name.split(":")[0] in uninitialized_var_names
]
self._sess.run(initialize_variables(uninitialized_vars))
self.loss_fn = getattr(losses, loss)
self._model.compile(
optimizer=optimizer_instance, loss=self.loss_fn, metrics=metrics
)

@staticmethod
def _trim_device_str(name):
return name.split("/")[0]

def get_model_weights(self) -> OrderDictWeights:
return OrderDictWeights(self._sess.run(self._trainable_weights))
return OrderDictWeights(self._trainable_weights)

def set_model_weights(self, weights: Weights):
unboxed = weights.unboxed
self._sess.run(
[assign(v, unboxed[name]) for name, v in self._trainable_weights.items()]
)
for name, v in self._trainable_weights.items():
v.assign(unboxed[name])

def get_layer_by_index(self, layer_idx):
return self._model.layers[layer_idx]
Expand All @@ -167,67 +116,61 @@ def set_layer_weights_by_index(self, layer_idx, weights):
self._model.layers[layer_idx].set_weights(weights)

def get_input_gradients(self, X, y):
return self._get_gradients(X, y, self._model.input)
with tf.GradientTape() as tape:
X = tf.constant(X)
y = tf.constant(y)
tape.watch(X)
loss = self._model.compiled_loss(y, self._model(X))
return [tape.gradient(loss, X).numpy()]

def get_trainable_gradients(self, X, y):
return self._get_gradients(X, y, self._trainable_weights)

def derivative_of_output_wrt_weights(self, X):
gradient = gradients(self._model.output, self._model.trainable_variables)
return self._sess.run(gradient, feed_dict={self._model.input: X})

def apply_gradients(self, grads):
update_ops = self._model.optimizer.apply_gradients(
self._model.optimizer.apply_gradients(
zip(grads, self._model.trainable_variables)
)
self._initialize_variables()
self._sess.run(update_ops)

def get_weight_gradients(self, X, y):
return self._get_gradients(X, y, self._model.trainable_variables)

def get_trainable_weights(self):
return self._sess.run(self._model.trainable_variables)
return [w.numpy() for w in self._model.trainable_weights]

def get_loss(self):
return self._loss

def get_forward_loss_from_input(self, X, y):
from federatedml.nn.hetero_nn.backend.tf_keras import losses

y_true = placeholder(
shape=self._model.output.shape, dtype=self._model.output.dtype
)

loss_fn = getattr(losses, self._model.loss_functions[0].fn.__name__)(
y_true, self._model.output
)
return self._sess.run(loss_fn, feed_dict={self._model.input: X, y_true: y})
# losses_fn = self._model.compiled_loss._losses
# LOGGER.error(f"{losses_fn}-{type(losses_fn)}")
# if isinstance(losses_fn, list):
# losses_fn = losses_fn[0]
# if isinstance(losses_fn, tf.keras.losses.LossFunctionWrapper):
# losses_fn = losses_fn.fn
# LOGGER.error(f"{losses_fn}-{type(losses_fn)}")
loss = self.loss_fn(tf.constant(y), self._model(X))
return loss.numpy()

def _get_gradients(self, X, y, variable):
from federatedml.nn.hetero_nn.backend.tf_keras import losses

y_true = placeholder(
shape=self._model.output.shape, dtype=self._model.output.dtype
)

loss_fn = getattr(losses, self._model.loss_functions[0].fn.__name__)(
y_true, self._model.output
)
gradient = gradients(loss_fn, variable)
return self._sess.run(gradient, feed_dict={self._model.input: X, y_true: y})
with tf.GradientTape() as tape:
y = tf.constant(y)
loss = self._model.compiled_loss(y, self._model(X))
g = tape.gradient(loss, variable)
if isinstance(g, list):
return [t.numpy() for t in g]
else:
return [g.numpy()]

def set_learning_rate(self, learning_rate):
assign_op = assign(self._model.optimizer.learning_rate, learning_rate)
self._sess.run(assign_op)
self._model.optimizer.learning_rate.assign(learning_rate)

def train(self, data: tf.keras.utils.Sequence, **kwargs):
epochs = 1
left_kwargs = copy.deepcopy(kwargs)
if "aggregate_every_n_epoch" in kwargs:
epochs = kwargs["aggregate_every_n_epoch"]
del left_kwargs["aggregate_every_n_epoch"]
left_kwargs["callbacks"] = [History()]
left_kwargs["callbacks"] = [tf.keras.callbacks.History()]
self._model.fit(x=data, epochs=epochs, verbose=1, shuffle=True, **left_kwargs)
self._loss = left_kwargs["callbacks"][0].history["loss"]
return epochs * len(data)
Expand All @@ -248,21 +191,8 @@ def export_model(self):
os.mkdir(model_base)
model_path = f"{model_base}/{uuid.uuid1()}"
os.mkdir(model_path)
try:
from tensorflow.keras.experimental import (
export_saved_model as save_model,
)

save_model(self._model, model_path)

except ImportError:
from tensorflow.compat.v1 import saved_model

saved_model.save(self._model, model_path)
# from tensorflow.keras.models import save_model

self._model.save(model_path)
model_bytes = _zip_dir_as_bytes(model_path)

return model_bytes

@staticmethod
Expand All @@ -278,24 +208,14 @@ def restore_model(
with zipfile.ZipFile(bytes_io, "r", zipfile.ZIP_DEFLATED) as f:
f.extractall(model_path)

try:
from tensorflow.keras.models import load_model

# add custom objects
from federatedml.nn.hetero_nn.backend.tf_keras.losses import (
keep_predict_loss,
)
# add custom objects
from federatedml.nn.hetero_nn.backend.tf_keras.losses import keep_predict_loss

tf.keras.utils.get_custom_objects().update(
{"keep_predict_loss": keep_predict_loss}
)

except ImportError:
from tensorflow.keras.experimental import (
load_from_saved_model as load_model,
)
tf.keras.utils.get_custom_objects().update(
{"keep_predict_loss": keep_predict_loss}
)

model = load_model(f"{model_path}")
model = tf.keras.models.load_model(f"{model_path}")

return KerasNNModel(model)

Expand Down
Loading

0 comments on commit 4c4e112

Please sign in to comment.