Skip to content

Commit

Permalink
Fix: Create save dir
Browse files Browse the repository at this point in the history
  • Loading branch information
Jahn Heymann committed Jan 18, 2016
1 parent 0fae0df commit 14e917b
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from nn_models import BLSTMMaskEstimator
from nn_models import SimpleFWMaskEstimator
from fgnt.utils import Timer
from fgnt.utils import mkdir_p

parser = argparse.ArgumentParser(description='NN GEV training')
parser.add_argument('data_dir', help='Directory used for the training data '
Expand Down Expand Up @@ -68,11 +69,14 @@
if args.model_type == 'BLSTM':
model = BLSTMMaskEstimator()
model_save_dir = os.path.join(args.data_dir, 'BLSTM_model')
mkdir_p(model_save_dir)
elif args.model_type == 'FW':
model = SimpleFWMaskEstimator()
model_save_dir = os.path.join(args.data_dir, 'FW_model')
mkdir_p(model_save_dir)
else:
raise ValueError('Unknown model type. Possible are "BLSTM" and "FW"')

if args.gpu >= 0:
cuda.get_device(args.gpu).use()
model.to_gpu()
Expand Down

0 comments on commit 14e917b

Please sign in to comment.