Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
luost26 committed Jan 20, 2022
1 parent 7bae7f2 commit cb8b91a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@ dmypy.json

.DS_Store
/logs*
/outputs*
10 changes: 2 additions & 8 deletions data/README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
# Datasets

- Datasets and splits are available at: [https://drive.google.com/drive/folders/1CzwxmTpjbrt83z_wBzcQncq84OVDPurM](https://drive.google.com/drive/folders/1CzwxmTpjbrt83z_wBzcQncq84OVDPurM)
- Please download both `crossdocked_pocket10.tar.gz` and `split_by_name.pt` from the link above.
- Then, extract `crossdocked_pocket10.tar.gz` here.


- `test_list.tsv` is a readable list of protein-ligand pairs.


1. Download the dataset archive `crossdocked_pocket10.tar.gz` and the split file `split_by_name.pt` from [this link](https://drive.google.com/drive/folders/1CzwxmTpjbrt83z_wBzcQncq84OVDPurM).
2. Extract the TAR archive using the command: `tar -xzvf crossdocked_pocket10.tar.gz`.
6 changes: 5 additions & 1 deletion models/sample_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def grid_refine(pos_init, batch, model, radius=0.5, resolution=0.1, device=None)
y_flat = y_cls.flatten()
p = (y_flat - y_flat.logsumexp(dim=0)).exp()
p_argmax = torch.multinomial(p, 1)[0]
pos_idx, type_idx = p_argmax // y_cls.size(1), p_argmax % y_cls.size(1)

# pos_idx, type_idx = p_argmax // y_cls.size(1), p_argmax % y_cls.size(1)
# [NOTE] operator // is deprecated by the latest version of PyTorch, use the following torch.div instead
pos_idx = torch.div(p_argmax, y_cls.size(1), rounding_mode='floor')
type_idx = p_argmax % y_cls.size(1)

pos_refined.append(pos_query[pos_idx].view(1, 3))
y_cls_refined.append(y_cls[pos_idx].view(1, -1))
Expand Down

0 comments on commit cb8b91a

Please sign in to comment.