Skip to content

Commit

Permalink
Make this file TensorFlow 2.0 compatible (tensorflow#7836)
Browse files Browse the repository at this point in the history
* Make this file TensorFlow 2.0 compatible

* Update remove_gt_colormap.py
  • Loading branch information
Lotte1990 authored and aquariusjay committed Nov 21, 2019
1 parent 7e12e4f commit 92384c6
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions research/deeplab/datasets/remove_gt_colormap.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
FLAGS = tf.compat.v1.flags.FLAGS

tf.app.flags.DEFINE_string('original_gt_folder',
'./VOCdevkit/VOC2012/SegmentationClass',
'Original ground truth annotations.')
tf.compat.v1.flags.DEFINE_string('original_gt_folder',
'./VOCdevkit/VOC2012/SegmentationClass',
'Original ground truth annotations.')

tf.app.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')
tf.compat.v1.flags.DEFINE_string('segmentation_format', 'png', 'Segmentation format.')

tf.app.flags.DEFINE_string('output_dir',
'./VOCdevkit/VOC2012/SegmentationClassRaw',
'folder to save modified ground truth annotations.')
tf.compat.v1.flags.DEFINE_string('output_dir',
'./VOCdevkit/VOC2012/SegmentationClassRaw',
'folder to save modified ground truth annotations.')


def _remove_colormap(filename):
Expand All @@ -59,14 +59,14 @@ def _save_annotation(annotation, filename):
filename: Output filename.
"""
pil_image = Image.fromarray(annotation.astype(dtype=np.uint8))
with tf.gfile.Open(filename, mode='w') as f:
with tf.io.gfile.GFile(filename, mode='w') as f:
pil_image.save(f, 'PNG')


def main(unused_argv):
# Create the output directory if not exists.
if not tf.gfile.IsDirectory(FLAGS.output_dir):
tf.gfile.MakeDirs(FLAGS.output_dir)
if not tf.io.gfile.isdir(FLAGS.output_dir):
tf.io.gfile.makedirs(FLAGS.output_dir)

annotations = glob.glob(os.path.join(FLAGS.original_gt_folder,
'*.' + FLAGS.segmentation_format))
Expand All @@ -80,4 +80,4 @@ def main(unused_argv):


if __name__ == '__main__':
tf.app.run()
tf.compat.v1.app.run()

0 comments on commit 92384c6

Please sign in to comment.