Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 321883870
  • Loading branch information
saberkun authored and tensorflower-gardener committed Jul 18, 2020
1 parent abd09bd commit 1759f3e
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 47 deletions.
2 changes: 2 additions & 0 deletions orbit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Orbit package definition."""

from orbit import utils
from orbit.controller import Controller
Expand Down
11 changes: 3 additions & 8 deletions orbit/controller.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,14 +15,8 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import time
from typing import Callable, Optional, Text, Union

from absl import logging
from orbit import runner
from orbit import utils
Expand All @@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name, interval, steps_per_loop))


class Controller(object):
class Controller:
"""Class that facilitates training and evaluation of models."""

def __init__(
Expand Down Expand Up @@ -396,7 +391,7 @@ def _maybe_save_checkpoint(self, force_trigger: bool = False):
return False


class StepTimer(object):
class StepTimer:
"""Utility class for measuring steps/second."""

def __init__(self, step):
Expand Down
7 changes: 2 additions & 5 deletions orbit/controller_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,10 +15,6 @@
# ==============================================================================
"""Tests for orbit.controller."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import logging
from absl.testing import parameterized
Expand Down Expand Up @@ -203,7 +200,7 @@ def _replicated_step(inputs):
class ControllerTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super(ControllerTest, self).setUp()
super().setUp()
self.model_dir = self.get_temp_dir()

def test_no_checkpoint(self):
Expand Down
13 changes: 3 additions & 10 deletions orbit/runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import abc
from typing import Dict, Optional, Text
import six
import tensorflow as tf


@six.add_metaclass(abc.ABCMeta)
class AbstractTrainer(tf.Module):
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for training."""

@abc.abstractmethod
Expand Down Expand Up @@ -56,8 +50,7 @@ def train(self,
pass


@six.add_metaclass(abc.ABCMeta)
class AbstractEvaluator(tf.Module):
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for evaluation."""

@abc.abstractmethod
Expand Down
13 changes: 3 additions & 10 deletions orbit/standard_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,21 +15,14 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import abc
from typing import Any, Dict, Optional, Text
from orbit import runner
from orbit import utils
import six
import tensorflow as tf


@six.add_metaclass(abc.ABCMeta)
class StandardTrainer(runner.AbstractTrainer):
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""

def __init__(self,
Expand Down Expand Up @@ -145,8 +139,7 @@ def train_dataset(self, train_dataset):
self._train_iter = None


@six.add_metaclass(abc.ABCMeta)
class StandardEvaluator(runner.AbstractEvaluator):
class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""

def __init__(self, eval_dataset, use_tf_function=True):
Expand Down
1 change: 1 addition & 0 deletions orbit/standard_runner_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
19 changes: 5 additions & 14 deletions orbit/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,18 +15,12 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import abc
import contextlib
import functools
import inspect

import numpy as np
import six
import tensorflow as tf


Expand Down Expand Up @@ -132,10 +127,7 @@ def dataset_fn(ctx):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
if six.PY3:
argspec = inspect.getfullargspec(dataset_or_fn)
else:
argspec = inspect.getargspec(dataset_or_fn) # pylint: disable=deprecated-method
argspec = inspect.getfullargspec(dataset_or_fn)
args_names = argspec.args

if "input_context" in args_names:
Expand All @@ -146,7 +138,7 @@ def dataset_fn(ctx):
return strategy.experimental_distribute_datasets_from_function(dataset_fn)


class SummaryManager(object):
class SummaryManager:
"""A class manages writing summaries."""

def __init__(self, summary_dir, summary_fn, global_step=None):
Expand Down Expand Up @@ -201,8 +193,7 @@ def write_summaries(self, items):
self._summary_fn(name, tensor, step=self._global_step)


@six.add_metaclass(abc.ABCMeta)
class Trigger(object):
class Trigger(metaclass=abc.ABCMeta):
"""An abstract class representing a "trigger" for some event."""

@abc.abstractmethod
Expand Down Expand Up @@ -263,7 +254,7 @@ def reset(self):
self._last_trigger_value = 0


class EpochHelper(object):
class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop."""

def __init__(self, epoch_steps, global_step):
Expand Down

0 comments on commit 1759f3e

Please sign in to comment.