Skip to content

Commit

Permalink
add: image save
Browse files Browse the repository at this point in the history
  • Loading branch information
youuijin committed Nov 18, 2023
1 parent 024666e commit 436cb25
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ data/
__pycache__/
images/
test.py
logs/
logs/
saved_images/*
12 changes: 11 additions & 1 deletion Manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torchvision import transforms
import torch, os


class Manager():
def __init__(self, model_name='mobilenet', attack='PGD_Linf', use=True):
Expand All @@ -8,6 +11,8 @@ def __init__(self, model_name='mobilenet', attack='PGD_Linf', use=True):
self.time = self.now.strftime('%m-%d_%H%M%S')
self.use = use
self.log_path = f'./logs/{model_name}_{attack}/{self.time}'
self.img_save_path = f'./saved_images/{model_name}_{attack}/{self.time}'
self.transform = transforms.ToPILImage()
if self.use :
self.writer = SummaryWriter(self.log_path)

Expand All @@ -19,4 +24,9 @@ def get_time(self):
return self.now

def get_log_path(self):
return self.log_path
return self.log_path

def save_image(self, episode, step, img):
os.makedirs(f"{self.img_save_path}/episode_{episode}/", exist_ok=True)
path = f"{self.img_save_path}/episode_{episode}/step_{step}.png"
self.transform(torch.tensor(img)).save(path)
7 changes: 7 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def parse_opt(known=False):
parser.add_argument("-a", "--alpha", type=float, default=0.5, help="hyperparameter alpha for cal Reward")
parser.add_argument("-name", "--model_name", type=str, default="mobilenet", help="attacked DNN model name")
parser.add_argument("-dataset", "--dataset_name", type=str, default="CIFAR10", help="train dataset name")

parser.add_argument("-save", "--image_save", action='store_true', default=False, help="save step images")

return parser.parse_known_args()[0] if known else parser.parse_args()

Expand All @@ -56,6 +58,7 @@ def main(conf):
num_step = conf["num_step"]
mode = conf["mode"]
print_interval = 100
save_interval = 500

# Env, Agent setting
env = Env(conf)
Expand All @@ -69,6 +72,8 @@ def main(conf):
for episode in tqdm(range(num_episode)):
epi_reward = 0
state, _ = env.reset()
if conf['image_save'] and episode%save_interval==0:
manager.save_image(episode, 0, state[0]) # 변화 없는 이미지 = 0
done = False
for step in range(num_step):
actions, action_probs = agent.get_actions(state)
Expand All @@ -79,6 +84,8 @@ def main(conf):
reward = reward.item()
epi_reward += reward
state = state_prime
if conf['image_save'] and episode%save_interval==0:
manager.save_image(episode, step+1, state[0])
if done :
break
loss = agent.train_net()
Expand Down

0 comments on commit 436cb25

Please sign in to comment.