Skip to content

Commit

Permalink
hotfix for loss type not being display properly
Browse files Browse the repository at this point in the history
  • Loading branch information
hill-a committed Aug 1, 2018
1 parent 7faddcc commit a528c49
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
9 changes: 4 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@
help='Enable use of multiple camera')
parser.add_argument('--balanced-sampling', action='store_true', default=False,
help='Force balanced sampling for episode independent prior instead of uniform')
parser.add_argument('--losses', nargs='+', default=["priors"], type=
loss_argument(
choices=["forward", "inverse", "reward", "priors", "episode-prior", "reward-prior", "triplet",
"autoencoder", "vae", "perceptual","dae"]),
help='The wanted losses. Can also impose weight for every defined loss: "<name>:<weight>".')
parser.add_argument('--losses', nargs='+', default=["priors"], **loss_argument(
choices=["forward", "inverse", "reward", "priors", "episode-prior", "reward-prior", "triplet",
"autoencoder", "vae", "perceptual", "dae"],
help='The wanted losses. Can also impose weight for every defined loss: "<name>:<weight>".'))
parser.add_argument('--beta', type=float, default=1.0,
help='(For beta-VAE only) Factor on the KL divergence, higher value means more disentangling.')
parser.add_argument('--split-index', type=int, default=-1,
Expand Down
26 changes: 17 additions & 9 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,35 @@
import argparse


def loss_argument(choices):
def loss_argument(choices, help):
"""
Creates a custom type for loss parsing
Creates a custom type for loss parsing, it overrides the type, choice and help of add_argument, in order to
properly extract the loss type, and still be able to print the choices available.
:param choices: ([str]) the list of valid losses
:return: (function (str): ((str, float) or str))
:return: (dict) the arguments for parse arg
"""
def _arg(arg):
def _arg_type(arg):
has_weight = ':' in arg
if has_weight:
if arg.split(':')[0] not in choices:
raise argparse.ArgumentError("invalid choice: {} (choose from {})".format(arg.split(':')[0], choices))
raise argparse.ArgumentTypeError("invalid choice: {} (choose from {})".format(arg.split(':')[0], choices))
try:
return (arg.split(':')[0], float(arg.split(':')[1]))
return arg.split(':')[0], float(arg.split(':')[1])
except ValueError:
raise argparse.ArgumentError("Error: must be of format '<str>:<float>' or '<str>'")
raise argparse.ArgumentTypeError("Error: must be of format '<str>:<float>' or '<str>'")
else:
if arg not in choices:
raise argparse.ArgumentError("invalid choice: {} (choose from {})".format(arg, choices))
raise argparse.ArgumentTypeError("invalid choice: {} (choose from {})".format(arg, choices))
return arg
return _arg

def _choices_print():
str_out = "{"
for loss in choices[:-1]:
str_out += loss + ", "
return str_out + choices[-1] + '}'

return {'type': _arg_type, 'help': _choices_print() + " " + help}

def buildConfig(args):
"""
Expand Down

0 comments on commit a528c49

Please sign in to comment.