Skip to content

Commit

Permalink
update files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
JuliaWolleb authored Dec 28, 2021
1 parent 08ce61e commit eeb28e3
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 82 deletions.
3 changes: 3 additions & 0 deletions guided_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Codebase for " Diffusion Models for Implicit Image Segmentation Ensembles".
"""
1 change: 0 additions & 1 deletion guided_diffusion/bratsloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __getitem__(self, x):
filedict = self.database[x]
for seqtype in self.seqtypes:
nib_img = nibabel.load(filedict[seqtype])
print('path', filedict[seqtype])
path=filedict[seqtype]
out.append(torch.tensor(nib_img.get_fdata()))
out = torch.stack(out)
Expand Down
1 change: 0 additions & 1 deletion guided_diffusion/dist_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def setup_dist():
s.listen(1)
port = s.getsockname()[1]
s.close()
print('port2', port)
os.environ["MASTER_PORT"] = str(port)
dist.init_process_group(backend=backend, init_method="env://")

Expand Down
1 change: 0 additions & 1 deletion guided_diffusion/fp16_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def backward(self, loss: th.Tensor):
(loss * loss_scale).backward()
else:
loss.backward()
print('grad0', loss.grad)

def optimize(self, opt: th.optim.Optimizer):
if self.use_fp16:
Expand Down
44 changes: 3 additions & 41 deletions guided_diffusion/gaussian_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""
from torch.autograd import Variable
Expand Down Expand Up @@ -29,7 +28,6 @@ def standardize(img):
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
"""
Get a pre-defined beta schedule for the given name.
The beta schedule library consists of beta schedules which remain similar
in the limit of num_diffusion_timesteps.
Beta schedules may be added, but should not be removed or changed once
Expand Down Expand Up @@ -57,7 +55,6 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
Expand Down Expand Up @@ -86,7 +83,6 @@ class ModelMeanType(enum.Enum):
class ModelVarType(enum.Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
Expand All @@ -112,10 +108,8 @@ def is_vb(self):
class GaussianDiffusion:
"""
Utilities for training and sampling diffusion models.
Ported directly from here, and then adapted over time to further experimentation.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
:param betas: a 1-D numpy array of betas for each diffusion timestep,
starting at T and going to 1.
:param model_mean_type: a ModelMeanType determining what the model outputs.
Expand Down Expand Up @@ -182,7 +176,6 @@ def __init__(
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
Expand All @@ -199,9 +192,7 @@ def q_mean_variance(self, x_start, t):
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
Expand All @@ -219,9 +210,7 @@ def q_sample(self, x_start, t, noise=None):
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior:
q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
Expand All @@ -247,7 +236,6 @@ def p_mean_variance(
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
:param model: the model, which takes a signal and a batch of timesteps
as input.
:param x: the [N x C x ...] tensor at time t.
Expand Down Expand Up @@ -374,7 +362,6 @@ def condition_mean(self, cond_fn, p_mean_var, x, t, org, model_kwargs=None):
computes the gradient of a conditional log probability with respect to
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
condition on y.
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
a, gradient = cond_fn(x, self._scale_timesteps(t),org, **model_kwargs)
Expand All @@ -389,9 +376,7 @@ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
"""
Compute what the p_mean_variance output would have been, should the
model's score function be conditioned by cond_fn.
See condition_mean() for details on cond_fn.
Unlike condition_mean(), this instead uses the conditioning strategy
from Song et al (2020).
"""
Expand Down Expand Up @@ -425,7 +410,6 @@ def p_sample(
):
"""
Sample x_{t-1} from the model at the given timestep.
:param model: the model to sample from.
:param x: the current tensor at x_{t-1}.
:param t: the value of t, starting at 0 for the first diffusion step.
Expand Down Expand Up @@ -471,7 +455,6 @@ def p_sample_loop(
):
"""
Generate samples from the model.
:param model: the model module.
:param shape: the shape of the samples, (N, C, H, W).
:param noise: if specified, the noise from the encoder to sample.
Expand Down Expand Up @@ -517,7 +500,6 @@ def p_sample_loop_known(
model_kwargs=None,
device=None,
progress=False,
conditioning=False,
conditioner = None,
classifier=None
):
Expand All @@ -540,16 +522,11 @@ def p_sample_loop_known(
model_kwargs=model_kwargs,
device=device,
progress=progress,
conditioning=conditioning,
conditioner=conditioner,
classifier=classifier
):
final = sample
if conditioning:
return final["sample"], x_noisy, org
else:

return final["sample"], x_noisy, img

return final["sample"], x_noisy, img

def p_sample_loop_progressive(
self,
Expand All @@ -568,7 +545,6 @@ def p_sample_loop_progressive(
"""
Generate samples from the model and yield intermediate samples from
each timestep of diffusion.
Arguments are the same as p_sample_loop().
Returns a generator over dicts, where each dict is the return value of
p_sample().
Expand Down Expand Up @@ -607,10 +583,7 @@ def p_sample_loop_progressive(
t,
clip_denoised=clip_denoised,
denoised_fn=denoised_fn,
cond_fn=cond_fn,
org=org,
model_kwargs=model_kwargs,
update_eps=True
)
yield out
img = out["sample"]
Expand All @@ -628,7 +601,6 @@ def ddim_sample(
):
"""
Sample x_{t-1} from the model using DDIM.
Same usage as p_sample().
"""
out = self.p_mean_variance(
Expand Down Expand Up @@ -768,7 +740,6 @@ def ddim_sample_loop(
):
"""
Generate samples from the model using DDIM.
Same usage as p_sample_loop().
"""
final = None
Expand Down Expand Up @@ -860,7 +831,6 @@ def ddim_sample_loop_progressive(
"""
Use DDIM to sample from the model and yield intermediate samples from
each timestep of DDIM.
Same usage as p_sample_loop_progressive().
"""
if device is None:
Expand Down Expand Up @@ -904,10 +874,8 @@ def _vb_terms_bpd(
):
"""
Get a term for the variational lower-bound.
The resulting units are bits (rather than nats, as one might expect).
This allows for comparison to other papers.
:return: a dict with the following keys:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
Expand Down Expand Up @@ -939,7 +907,6 @@ def _vb_terms_bpd(
def training_losses_segmentation(self, model, classifier, x_start, t, model_kwargs=None, noise=None):
"""
Compute training losses for a single timestep.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param t: a batch of timestep indices.
Expand Down Expand Up @@ -1013,9 +980,7 @@ def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
Expand All @@ -1031,13 +996,11 @@ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
"""
Compute the entire variational lower-bound, measured in bits-per-dim,
as well as other related quantities.
:param model: the model to evaluate loss on.
:param x_start: the [N x C x ...] tensor of inputs.
:param clip_denoised: if True, clip denoised samples.
:param model_kwargs: if not None, a dict of extra keyword arguments to
pass to the model. This can be used for conditioning.
:return: a dict containing the following keys:
- total_bpd: the total variational lower-bound, per batch element.
- prior_bpd: the prior term in the lower-bound.
Expand Down Expand Up @@ -1089,7 +1052,6 @@ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
Expand All @@ -1099,4 +1061,4 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
return res.expand(broadcast_shape)
29 changes: 3 additions & 26 deletions guided_diffusion/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,6 @@ def run_loop(self):
# reinitialize data loader
data_iter = iter(self.dataloader)
batch, cond = next(data_iter)
# viz.image(visualize(batch[0, 0, ...]), opts=dict(caption="batch0"))
# viz.image(visualize(batch[0, 1, ...]), opts=dict(caption="batch1"))
# viz.image(visualize(batch[0, 2, ...]), opts=dict(caption="batch2"))
# viz.image(visualize(batch[0, 3, ...]), opts=dict(caption="batch3"))
# viz.image(visualize(cond[0, 0, ...]), opts=dict(caption="cond0"))

self.run_step(batch, cond)

Expand All @@ -198,22 +193,7 @@ def run_loop(self):
totseg += lossseg;
totcls += losscls
totrec += lossrec
if i % 10 == 0:
viz.line(X=th.ones((1, 1)).cpu() * i, Y=th.Tensor([totcls]).unsqueeze(0).cpu(),
win=loss_window, name='loss_vb',
update='append')
viz.line(X=th.ones((1, 1)).cpu() * i, Y=th.Tensor([totseg]).unsqueeze(0).cpu(),
win=loss_window, name='loss_seg',
update='append')
viz.line(X=th.ones((1, 1)).cpu() * i, Y=th.Tensor([totrec]).unsqueeze(0).cpu(),
win=loss_window, name='loss_rec',
update='append')
totseg = 0
totcls = 0
totrec=0

if i % 200 == 0:
viz.image(visualize(sample[0, 0, ...]), opts=dict(caption="sampled output"))


if self.step % self.log_interval == 0:
logger.dumpkvs()
Expand Down Expand Up @@ -244,14 +224,11 @@ def forward_backward(self, batch, cond):
self.mp_trainer.zero_grad()
for i in range(0, batch.shape[0], self.microbatch):
micro = batch[i : i + self.microbatch].to(dist_util.dev())
# micro_cond = cond[i: i + self.microbatch].to(dist_util.dev())
micro_cond = {
k: v[i : i + self.microbatch].to(dist_util.dev())
for k, v in cond.items()
}

# viz.image(visualize(batch[0,0, ...]), opts=dict(caption="micro"))
# viz.image(visualize(batch[0,4, ...]), opts=dict(caption="micro_cond"))
last_batch = (i + self.microbatch) >= batch.shape[0]
t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

Expand Down Expand Up @@ -280,8 +257,8 @@ def forward_backward(self, batch, cond):

loss = (losses["loss"] * weights).mean()
lossseg = (losses["mse"] * weights).mean().detach()
losscls = (losses["vb"] * weights).mean().detach()#0.1*( (losses["cls2"] * weights).mean().detach()+(losses["cls3"] * weights).mean().detach())
lossrec =loss*0#10*( (losses["rec1"] * weights).mean().detach()+(losses["rec2"] * weights).mean().detach())
losscls = (losses["vb"] * weights).mean().detach()
lossrec =loss*0

log_loss_dict(
self.diffusion, t, {k: v * weights for k, v in losses.items()}
Expand Down
17 changes: 5 additions & 12 deletions scripts/segmentation_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def dice_score(pred, targs):

def main():
args = create_argparser().parse_args()
#result0=th.load('./Bratssliced/validation/000246/result0')
# print('loadedresult0', result0.shape)
dist_util.setup_dist()
logger.configure()

Expand All @@ -60,7 +58,6 @@ def main():
shuffle=False)
data = iter(datal)
all_images = []
all_labels = []
model.load_state_dict(
dist_util.load_state_dict(args.model_path, map_location="cpu")
)
Expand All @@ -69,19 +66,16 @@ def main():
model.convert_to_fp16()
model.eval()
while len(all_images) * args.batch_size < args.num_samples:
img, path = next(data) #should return an image from the dataloader "data"
# c = th.randn_like(b[:, :1, ...])
# img = th.cat((b, c), dim=1) #add a noise channel$
print('path', path)
b, path = next(data) #should return an image from the dataloader "data"
c = th.randn_like(b[:, :1, ...])
img = th.cat((b, c), dim=1) #add a noise channel$
slice_ID=path[0].split("/", -1)[3]

cond = {}
viz.image(visualize(img[0,0,...]), opts=dict(caption="img input0"))
viz.image(visualize(img[0, 1, ...]), opts=dict(caption="img input1"))
viz.image(visualize(img[0, 2, ...]), opts=dict(caption="img input2"))
viz.image(visualize(img[0, 3, ...]), opts=dict(caption="img input3"))
# viz.image(visualize(img[0, 4, ...]), opts=dict(caption="img input4"))

viz.image(visualize(img[0, 4, ...]), opts=dict(caption="img input4"))

logger.log("sampling...")

Expand All @@ -107,9 +101,8 @@ def main():
print('time for 1 sample', start.elapsed_time(end)) #time measurement for the generation of 1 sample

s = th.tensor(sample)
mask = th.where(sample > 0.5, 1, 0)
viz.image(visualize(sample[0, 0, ...]), opts=dict(caption="sampled output"))
th.save(s, './results/generated_masks/'+str(slice_ID)+'_output'+str(i)) #save the generated mask
th.save(s, './results/'+str(slice_ID)+'_output'+str(i)) #save the generated mask

def create_argparser():
defaults = dict(
Expand Down

0 comments on commit eeb28e3

Please sign in to comment.