-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path_0_init_data.py
66 lines (56 loc) · 1.68 KB
/
_0_init_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Download the COLO dataset configurations from the HuggingFace Hub to the
local machine and convert them to YOLO format.
"""
# 1. Imports -------------------------------------------------------------------
# native imports
import os
import sys
import argparse
# sys.path.insert(0, os.path.join("/home", "niche", "pyniche"))
# custom imports
from pyniche.data.yolo.API import YOLO_API
from pyniche.data.huggingface.detection import hf_to_yolo
# huggingface imports
import datasets
# 2. Global Variables ----------------------------------------------------------
# Absolute directory containing this script (not referenced elsewhere in the
# visible code).
ROOT = os.path.dirname(os.path.abspath(__file__))

# HuggingFace Hub dataset identifier passed to datasets.load_dataset.
DATASET = "Niche-Squad/COLO"

# Dataset configuration names to download; each one is converted into its
# own sub-directory of --dir_data by main().
CONFIGS = [
    # numbered configurations
    "0_all", "1_top", "2_side", "3_external",
    # lettered configurations
    "a1_t2s", "a2_s2t", "b_light", "c_external",
]
# 3. Download Data -------------------------------------------------------------
def main(args):
    """Download every COLO configuration and convert it to YOLO format.

    For each name in ``CONFIGS`` the configuration is loaded from the
    HuggingFace Hub dataset ``DATASET`` with a forced re-download, using
    ``<dir_data>/.huggingface`` as the cache directory. The result is then
    passed to ``hf_to_yolo`` together with ``<dir_data>/<config>``
    (presumably the output directory), the single class ``"cow"`` and
    ``size_new=(640, 640)``.

    Args:
        args: Parsed command-line namespace; its ``dir_data`` attribute is
            the root directory for the cache and converted datasets.

    Raises:
        ValueError: If ``args.dir_data`` is ``None``.
    """
    dir_data = args.dir_data
    if dir_data is None:
        # Without this guard os.path.join(None, ...) raises an opaque
        # TypeError; fail fast with an actionable message instead.
        raise ValueError("dir_data is required (pass --dir_data <path>)")
    # One HuggingFace cache shared by every configuration (loop-invariant).
    cache_dir = os.path.join(dir_data, ".huggingface")
    for config in CONFIGS:
        dir_config = os.path.join(dir_data, config)
        hf_dataset = datasets.load_dataset(
            DATASET,
            config,
            # Always fetch fresh data rather than reusing the cache contents.
            download_mode="force_redownload",
            cache_dir=cache_dir,
        )
        # Convert to YOLO layout for YOLOv8.
        hf_to_yolo(
            hf_dataset,
            dir_config,
            classes=["cow"],
            size_new=(640, 640),
        )
        # NOTE: per-thread dataset cloning via YOLO_API(dir_config).clone()
        # was previously sketched here but is disabled.
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the download.
    parser = argparse.ArgumentParser(
        description="Download COLO configurations from HuggingFace and "
        "convert them to YOLO format."
    )
    # main() cannot run without an output root, so make the flag mandatory
    # and let argparse report a missing value with a usage message.
    parser.add_argument(
        "--dir_data",
        required=True,
        help="root directory for the HuggingFace cache and converted datasets",
    )
    args = parser.parse_args()
    main(args)