Skip to content

Commit

Permalink
add download script
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoon Kim authored and Hoon Kim committed Mar 11, 2019
1 parent 61f0b7a commit 5ce036e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
Binary file added checkpoints/.DS_Store
Binary file not shown.
46 changes: 46 additions & 0 deletions checkpoints/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import requests, os

def download_file_from_google_drive(id, destination):
URL = "https://docs.google.com/uc?export=download"

session = requests.Session()

response = session.get(URL, params = { 'id' : id }, stream = True)
token = get_confirm_token(response)

if token:
params = { 'id' : id, 'confirm' : token }
response = session.get(URL, params = params, stream = True)

save_response_content(response, destination)

def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value

return None

def save_response_content(response, destination):
CHUNK_SIZE = 32768

with open(destination, "wb") as f:
for chunk in response.iter_content(CHUNK_SIZE):
if chunk: # filter out keep-alive new chunks
f.write(chunk)

if __name__ == "__main__":
print("Downloading pretrained model...")
file_id = str("1ZNJSAN-96CVUakwtd1YMW70t2YXSVIdn")
destination = str("./checkpoints/gtaCrash.1.0.t-1.8.zip")
download_file_from_google_drive(file_id, destination)
print("Download completed!")

print("Unzipping pretrained model...")
os.system("unzip ./checkpoints/gtaCrash.1.0.t-1.8.zip -d ./checkpoints")
print("Unzipping completed!")

os.system("rm ./checkpoints/gtaCrash.1.0.t-1.8.zip")
print("Zip file removed")


4 changes: 2 additions & 2 deletions scripts/train_gta_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
feature_extractor = 'vgg16'
optimizer = 'adam'
lr = 0.0001
decay_lr_per = 3
decay_lr_per = 2
nepoch = 10
init_steps_to_skip_eval = 1500

Expand All @@ -23,7 +23,7 @@
for n_RGBs, n_BBs in [(3,3)]:

os.system("python train.py " + \
"--name temp.gtaCrash.{train_proportion}.t-{ttc}s/{label_method}-{motion_model}/".format(
"--name gtaCrash.{train_proportion}.t-{ttc}s/{label_method}-{motion_model}/".format(
train_proportion=train_proportion, ttc=ttc,
label_method=label_method, motion_model=motion_model) +\
"{n_RGBs}rgb{n_BBs}b.{feature_extractor}.{optimizer}.lr{lr}.decay_lr_per{decay_lr_per}.seed{seed} ".format(
Expand Down

0 comments on commit 5ce036e

Please sign in to comment.