Skip to content

Commit

Permalink
Don't forget to install torch
Browse files Browse the repository at this point in the history
  • Loading branch information
ostrokach committed Mar 26, 2021
1 parent f914c2c commit 48b2928
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Replace `cu101` with the desired CUDA version.

```bash
pip install -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html --default-timeout=600 \
"torch==1.8.1" \
"transformers==3.3.1" \
"torch-scatter==2.0.6" \
"torch-sparse==0.6.9" \
Expand Down
16 changes: 14 additions & 2 deletions scripts/get_pip_install_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def clean_notebook_text(text):
return text.strip()


def command_to_colab(command):
command = command.replace('"torch==1.8.1" ', "")
return command


def get_match(text, command):
idx = text.index("https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html")
match = text[idx - 15 : idx + len(command) - 15]
Expand All @@ -31,6 +36,7 @@ def get_match(text, command):

command = r"""
pip install -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html --default-timeout=600 \
"torch==1.8.1" \
"transformers==3.3.1" \
"torch-scatter==2.0.6" \
"torch-sparse==0.6.9" \
Expand All @@ -50,10 +56,16 @@ def get_match(text, command):

with ROOT_DIR.joinpath("notebooks", "10_stability_demo.ipynb").open("rt") as fin:
notebook_text = clean_notebook_text(fin.read())
assert command in notebook_text, (command, get_match(notebook_text, command))
assert command_to_colab(command) in notebook_text, (
command_to_colab(command),
get_match(notebook_text, command_to_colab(command)),
)

with ROOT_DIR.joinpath("notebooks", "10_affinity_demo.ipynb").open("rt") as fin:
notebook_text = clean_notebook_text(fin.read())
assert command in notebook_text, (command, get_match(notebook_text, command))
assert command_to_colab(command) in notebook_text, (
command_to_colab(command),
get_match(notebook_text, command_to_colab(command)),
)

print(command)

0 comments on commit 48b2928

Please sign in to comment.