
Commit

update: huggingface source to ckpt
TangentOne committed Dec 13, 2024
1 parent 866974d commit b346ab0
Showing 6 changed files with 32 additions and 194 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@
*data*
*.pyc*
*.json
ckpts/*
37 changes: 20 additions & 17 deletions README.md
@@ -2,8 +2,9 @@

![Teaser image](./docs/selected_pics.png)

[Preprint] Generative Modeling with Explicit Memory <br>
Yi Tang, Peng Sun, Zhenglin Cheng, Tao Lin
**[Preprint] Generative Modeling with Explicit Memory** <br>
Yi Tang, Peng Sun, Zhenglin Cheng, Tao Lin <br>
https://arxiv.org/abs/2412.08781<br>


### Abstract
@@ -13,7 +14,7 @@ Yi Tang, Peng Sun, Zhenglin Cheng, Tao Lin
---


### System Requirements
### Requirements

- **Python and PyTorch:**
- 64-bit Python 3.10 or later.
@@ -30,31 +31,24 @@ Yi Tang, Peng Sun, Zhenglin Cheng, Tao Lin

### Getting Started

To reproduce the primary results from the paper, run the following script:
To reproduce the results from the paper, run the following script:

```bash
bash scripts/sample-gmem-xl.sh
```

This is a minimal standalone script that loads the best pre-trained model and generates 50K images.
**Important:** make sure to change `--ckpt` to the correct checkpoint path.

---

### Pre-trained Models and Memory Bank

We offer the following pre-trained models and memory bank here:
We provide the following pre-trained model and memory bank:

#### Model Checkpoints
| Model Backbone | Training Steps | File Location |
|----------------------|----------------|------------------------------|
| SiT-XL/2 | 2M | [Download Here](#) |

#### Memory Bank
| Dataset | Resolution | Snippets | Training Epo. | File Location |
|----------------------|----------------|------------------|----------------|------------------------------|
| ImageNet | 256×256 | 640,000 | 5 | [Download Here](#) |

**Important:** Ensure that both `bank.pth` and `bank.freq` are saved in the same directory to enable proper functionality.
#### GMem Checkpoints
| Backbone | Training Steps | Dataset | Bank Size | Training Epochs | Download |
|----------------|----------------|---------------------------|-----------|---------------|----------|
| SiT-XL/2 | 2M | ImageNet $256\times 256$ | 640,000 | 5 | [Huggingface](https://huggingface.co/Tangentone/GMem) |
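
The checkpoint can also be fetched programmatically. Below is a minimal download sketch using `huggingface_hub`; the exact `filename` is an assumption taken from `scripts/sample-gmem-xl.sh`, so verify it against the file listing on the Hub page.

```python
# Minimal download sketch (assumes the file name referenced in
# scripts/sample-gmem-xl.sh; check https://huggingface.co/Tangentone/GMem
# for the actual file listing).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="Tangentone/GMem",
    filename="GMem_XL_2Miter_ImageNet-1K_K640000_5epo.pth",
    local_dir="./ckpts",
)
print(ckpt_path)  # pass this path to --ckpt in scripts/sample-gmem-xl.sh
```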

---

@@ -72,5 +66,14 @@ This code is mainly built upon [SiT](https://github.com/willisma/SiT), [edm2](ht
### BibTeX

```bibtex
@misc{tang2024generativemodelingexplicitmemory,
title={Generative Modeling with Explicit Memory},
author={Yi Tang and Peng Sun and Zhenglin Cheng and Tao Lin},
year={2024},
eprint={2412.08781},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.08781},
}
```

74 changes: 0 additions & 74 deletions create_env.sh

This file was deleted.

20 changes: 10 additions & 10 deletions engine/generate.py
@@ -57,16 +57,17 @@ def main(args):
# Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
ckpt_path = args.ckpt

state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')['ema']
model.load_state_dict(state_dict)
GMem_state_dict = torch.load(ckpt_path, map_location=f'cuda:{device}')

model.load_state_dict(GMem_state_dict['network'])
model.eval() # important!
vae = AutoencoderKL.from_pretrained(f"pretrains/stabilityai/sd-vae-ft-{args.vae}").to(device)

print(f'loading bank:{args.bank_path}')
bank_path: str = args.bank_path
bank = torch.load(f=bank_path).to('cpu')
freq_path = bank_path.replace('pth', 'freq')
freq = torch.load(f=freq_path)
print(f'Loading bank...')
bank = GMem_state_dict['memorybank']
freq = GMem_state_dict['memoryfreq']
print(f'Bank Loaded with {len(freq)} snippets!')


# Create folder to save samples:
folder_name = f"GMem-XL-2000000-ImageNet256x256-bank640000"
@@ -157,7 +158,6 @@ def main(args):

# logging/saving:
parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
parser.add_argument("--bank-path", type=str, default=None)
parser.add_argument("--sample-dir", type=str, default="outputs/samples")

# model
@@ -176,7 +176,7 @@
parser.add_argument("--num-fid-samples", type=int, default=50_000)

# sampling related hyperparameters
parser.add_argument("--mode", type=str, default="ode")
parser.add_argument("--mode", type=str, default="sde")
parser.add_argument("--cfg-scale", type=float, default=0)
parser.add_argument("--projector-embed-dims", type=str, default="768,1024")
parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
@@ -185,7 +185,7 @@
parser.add_argument("--guidance-low", type=float, default=0.)
parser.add_argument("--guidance-high", type=float, default=0.)

# GMem
# GMem required!
parser.add_argument("--use-feature-condition", action=argparse.BooleanOptionalAction, default=False)


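For reference, here is a minimal sketch of inspecting the unified checkpoint that the updated loading code above expects: a single `.pth` file holding the `network`, `memorybank`, and `memoryfreq` entries (replacing the earlier separate `ema` checkpoint plus `bank.pth`/`bank.freq` files). The path is an assumption based on `scripts/sample-gmem-xl.sh`.

```python
# Minimal inspection sketch for the unified GMem checkpoint
# (keys follow engine/generate.py above; the path is an assumption
# taken from scripts/sample-gmem-xl.sh).
import torch

ckpt = torch.load(
    "./ckpts/GMem_XL_2Miter_ImageNet-1K_K640000_5epo.pth",
    map_location="cpu",
)

print(sorted(ckpt.keys()))      # expected: ['memorybank', 'memoryfreq', 'network']
print(len(ckpt["memoryfreq"]))  # number of memory snippets (640,000 for this release)
```
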
3 changes: 1 addition & 2 deletions scripts/sample-gmem-xl.sh
@@ -12,5 +12,4 @@ torchrun --nnodes=1 --nproc_per_node=8 engine/generate.py \
--cfg-scale=0 \
--guidance-high=0 \
--use-feature-condition \
--ckpt PATH/TO/CKPT \
--bank-path PATH/TO/BANK \
--ckpt "./ckpts/GMem_XL_2Miter_ImageNet-1K_K640000_5epo.pth" \ # CHANGE THIS!
91 changes: 0 additions & 91 deletions utils/utils.py

This file was deleted.
