Skip to content

Commit

Permalink
Merge pull request #3 from Katsumata420/introduce-hf-datasets
Browse files Browse the repository at this point in the history
Added cache_dir to Args
  • Loading branch information
Katsumata420 authored Dec 15, 2022
2 parents ddcd365 + 36ce7a1 commit 92b1f02
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
5 changes: 4 additions & 1 deletion blink/biencoder/train_biencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,10 @@ def main(params):

# Load train data
train_samples = load_dataset(
"json", data_files={"train": os.path.join(params["data_path"], "train.jsonl")}, streaming=False
"json",
data_files={"train": os.path.join(params["data_path"], "train.jsonl")},
streaming=False,
cache_dir=params["cache_dir"],
)["train"]
logger.info("Read %d train samples." % len(train_samples))

Expand Down
6 changes: 6 additions & 0 deletions blink/common/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ def add_model_args(self, args=None):
required=True,
help="The output directory where generated output file (model, etc.) is to be dumped.",
)
parser.add_argument(
"--cache_dir",
default=None,
type=str,
help="dataset cache dir. Default path may be /home/user/.cache/...",
)


def add_training_args(self, args=None):
Expand Down

0 comments on commit 92b1f02

Please sign in to comment.