Skip to content

Commit

Permalink
fix hyper param error
Browse files Browse the repository at this point in the history
  • Loading branch information
ruhyadi committed Mar 10, 2022
1 parent 5902311 commit 6ee791e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import os
from random import shuffle
import sys
from pathlib import Path

Expand Down Expand Up @@ -75,7 +76,11 @@ def train(
experiment.log_parameters(hyper_params)

# data generator
data_gen = data.DataLoader(dataset, **hyper_params)
data_gen = data.DataLoader(
dataset,
batch_size=hyper_params['batch_size'],
shuffle=hyper_params['shuffle'],
num_workers=hyper_params['num_workers'])

# model
base_model = model_factory[select_model]
Expand Down

0 comments on commit 6ee791e

Please sign in to comment.