Skip to content

Commit

Permalink
Can now get stats og real_images separately
Browse files Browse the repository at this point in the history
  • Loading branch information
evenmn committed May 6, 2021
1 parent b72d17a commit 021d3e3
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 30 deletions.
5 changes: 3 additions & 2 deletions pytorch_sfid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__version__ = "0.0.1"
from pytorch_sfid import params
__version__ = "0.0.2"
import pytorch_fid_wrapper as pfw
from pytorch_sfid import params
from pytorch_sfid.sfid import get_sfid, get_stats


def set_config(ncenters=None, radius=None, batch_size=None, dims=None,
Expand Down
133 changes: 105 additions & 28 deletions pytorch_sfid/sfid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,71 @@
from pytorch_sfid import params as ps_params


def get_stats(real):
pass
def get_bins(attr, ncenters, radius):
"""Sort attributes into bins and return indices
Parameters:
-----------
attr : torch FloatTensor, N x Na
Attributes
"""
# sort attributes
centers = torch.linspace(0, 1, ncenters)
lower_edge = centers - radius
upper_edge = centers + radius
smt = torch.logical_and(attr[:, :, None] > lower_edge[None, None, :],
attr[:, :, None] < upper_edge[None, None, :])

# sort into bins by doing the generalized out product between all rows
smt = smt.swapaxes(0, 1).swapaxes(1, 2)
bins = smt[0]
for i in range(1, attr.shape[1]):
bins = torch.einsum('i...,j...->ij...', bins, smt[i])

# return flattened bin tensor
return bins.reshape(-1, attr.shape[0])


def get_stats(real_images, real_attr, ncenters=None, radius=None,
batch_size=None, dims=None, device=None):
"""Get statistics of real images
Parameters:
-----------
real_images : torch FloatTensor, N x C x H x W
"""

if ncenters is None:
ncenters = ps_params.ncenters
if radius is None:
radius = ps_params.radius

min_attr, _ = real_attr.min(0)
max_attr, _ = real_attr.max(0)

real_attr = (real_attr - min_attr) / (max_attr - min_attr)

bins = get_bins(real_attr, ncenters, radius)

real_m, real_s = [], []
for i in range(bins.shape[0]):
indices = torch.where(bins[i])[0]
real_local = real_images[indices]
real_local = real_local.repeat(1, 3, 1, 1)

if real_local.shape[0] > 1:
m, s = pfw.get_stats(real_local, batch_size, dims, device)
real_m.append(m)
real_s.append(s)
else:
real_m.append(None)
real_s.append(None)
return real_m, real_s, min_attr, max_attr


def sfid_score(real_images, real_attr, fake_images, fake_attr, ncenters=None,
radius=None, batch_size=None, dims=None, device=None):
def get_sfid(fake_images, fake_attr, real_images=None, real_attr=None,
real_stats=None, ncenters=None, radius=None, batch_size=None,
dims=None, device=None):
"""Sliding Frechet Inception Distance
Parameters
Expand All @@ -22,30 +81,38 @@ def sfid_score(real_images, real_attr, fake_images, fake_attr, ncenters=None,
fake_attr : torch FloatTensor, Nf x Na
Attributes of fake images
"""

assert real_images.shape[0] == real_attr.shape[0]
assert (real_images is not None and real_attr is not None) or real_stats is not None
assert fake_images.shape[0] == fake_attr.shape[0]
assert real_attr.shape[1] == fake_attr.shape[1]

nreal = real_images.shape[0]
nfake = fake_images.shape[0]
ncond = real_attr.shape[1]
ncond = fake_attr.shape[1]
nbins = ncenters ** ncond

if ncenters is None:
ncenters = ps_params.ncenters
if radius is None:
radius = ps_params.radius

# standardize attributes (between 0 and 1)
attr = torch.cat([real_attr, fake_attr], dim=0)
if real_images is not None:
# standardize attributes (between 0 and 1)
attr = torch.cat([real_attr, fake_attr], dim=0)

min_attr, _ = attr.min(0)
max_attr, _ = attr.max(0)
min_attr, _ = attr.min(0)
max_attr, _ = attr.max(0)

real_attr = (real_attr - min_attr) / (max_attr - min_attr)
fake_attr = (fake_attr - min_attr) / (max_attr - min_attr)
real_attr = (real_attr - min_attr) / (max_attr - min_attr)
fake_attr = (fake_attr - min_attr) / (max_attr - min_attr)

# get bins
bins_real = get_bins(real_attr, ncenters, radius)
bins_fake = get_bins(fake_attr, ncenters, radius)

else:
real_m, real_s, min_attr, max_attr = real_stats
fake_attr = (fake_attr - min_attr) / (max_attr - min_attr)
bins_fake = get_bins(fake_attr, ncenters, radius)

'''
# sort attributes in (overlapping) Na-dimensional bins
centers = torch.linspace(0, 1, ncenters)
lower_edge = centers - radius
Expand All @@ -67,26 +134,36 @@ def sfid_score(real_images, real_attr, fake_images, fake_attr, ncenters=None,
bins_real = bins_real.reshape(-1, nreal)
bins_fake = bins_fake.reshape(-1, nfake)
'''

val_fid = 0
for i in range(nbins):
indices_real = torch.where(bins_real[i])[0]
indices_fake = torch.where(bins_fake[i])[0]
real_local = real_images[indices_real]
fake_local = fake_images[indices_fake]

real_local = real_local.repeat(1, 3, 1, 1)
fake_local = fake_local.repeat(1, 3, 1, 1)

if real_local.shape[0] > 1 and fake_local.shape[0] > 1:
print(real_local.shape[0], fake_local.shape[0])
val_fid += pfw.fid(fake_local, real_images=real_local,
batch_size=batch_size, dims=dims, device=device)
if real_images is not None:
indices_real = torch.where(bins_real[i])[0]
real_local = real_images[indices_real]
real_local = real_local.repeat(1, 3, 1, 1)

if real_local.shape[0] > 1 and fake_local.shape[0] > 1:
val_fid += pfw.fid(fake_local, real_images=real_local,
batch_size=batch_size, dims=dims, device=device)
else:
real_m_local, real_s_local = real_m[i], real_s[i]
if real_s_local is not None and fake_local.shape[0] > 1:
val_fid += pfw.fid(fake_local, real_m=real_m_local, real_s=real_s_local,
batch_size=batch_size, dims=dims, device=device)
print(val_fid)
return val_fid / nbins


if __name__ == "__main__":
real_images = torch.rand(10000, 1, 128, 128)
real_attr = torch.rand((10000, 3))
fake_attr = torch.rand((10000, 3))
sfid_score(real_images, real_attr, real_images, fake_attr, radius=0.19, ncenters=3)
N = 1000
real_images = torch.rand(N, 1, 128, 128)
real_attr = torch.rand((N, 3))
fake_attr = torch.rand((N, 3))

real_stats = get_stats(real_images, real_attr)
sfid_score(real_images, real_attr, real_stats=real_stats, radius=0.6, ncenters=3)

0 comments on commit 021d3e3

Please sign in to comment.