forked from instadeepai/aichor-demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
32 lines (24 loc) · 990 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import time
from src.operators.jax import jaxop
from src.operators.ray import rayop
from src.operators.tf import tfop
from src.utils.tensorboard import dummy_tb_write
_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})


def str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` is a trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so an explicit converter is needed.

    Args:
        value: the raw option value (already-bool values are passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse turns this into a clean usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for the AIchor smoke test.

    Returns:
        An ``argparse.ArgumentParser`` with ``--operator``, ``--sleep`` and
        ``--tb-write`` options.
    """
    parser = argparse.ArgumentParser(description='AIchor Smoke test on any operator')
    parser.add_argument("--operator", type=str, default="tf", choices=["ray", "jax", "tf"], help="operator name")
    # Integer default (was the string "0", which only worked because argparse
    # happens to run string defaults through `type`).
    parser.add_argument("--sleep", type=int, default=0, help="sleep time in seconds")
    # nargs="?" + const=True keeps both `--tb-write` (bare flag) and the
    # previously-documented `--tb-write True` form working, while making
    # `--tb-write False` actually disable the write.
    parser.add_argument("--tb-write", type=str_to_bool, nargs="?", const=True, default=False, help="test write to tensorboard")
    return parser


def main(argv=None):
    """Run the smoke test for the selected operator.

    Runs the chosen operator, then optionally writes to TensorBoard, then
    optionally sleeps before returning.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    print(f"using {args.operator} operator")
    # `choices` guarantees the key exists, so plain indexing is safe.
    operators = {"ray": rayop, "jax": jaxop, "tf": tfop}
    operators[args.operator]()
    if args.tb_write:
        dummy_tb_write()
    if args.sleep > 0:
        print(f"sleeping for {args.sleep}s before exiting")
        time.sleep(args.sleep)


if __name__ == "__main__":
    main()