A collection of measures and metrics for automatic image quality assessment in various image-to-image tasks such as denoising, super-resolution, image generation, etc. This easy-to-use yet flexible and extensive library is developed with a focus on reliability and reproducibility of results. Use your favourite measures as losses for training neural networks with ready-to-use PyTorch modules.
import torch
from piq import ssim
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index = ssim(prediction, target, data_range=1.)
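The loss counterparts of these measures plug into a regular training loop. Below is a minimal sketch of a single training step; the tiny convolutional model, optimizer and random data are placeholders for illustration only, and only SSIMLoss comes from piq:
import torch
from piq import SSIMLoss

# Placeholder model and optimizer, used only to illustrate the training step.
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = SSIMLoss(data_range=1.)

noisy = torch.rand(4, 3, 256, 256)
clean = torch.rand(4, 3, 256, 256)

optimizer.zero_grad()
# Sigmoid keeps predictions in [0, 1] to match data_range=1.
prediction = torch.sigmoid(model(noisy))
loss = criterion(prediction, clean)
loss.backward()
optimizer.step()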
Blind/Referenceless Image Spatial Quality Evaluator (BRISQUE)
To compute the BRISQUE score as a measure, use the lower-case function from the library:
import torch
from piq import brisque
prediction = torch.rand(3, 3, 256, 256)
brisque_index: torch.Tensor = brisque(prediction, data_range=1.)
In order to use BRISQUE as a loss function, use the corresponding PyTorch module:
import torch
from piq import BRISQUELoss
loss = BRISQUELoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()
Feature Similarity Index Measure (FSIM)
To compute FSIM as a measure, use the lower-case function from the library:
import torch
from piq import fsim
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
fsim_index: torch.Tensor = fsim(prediction, target, data_range=1.)
In order to use FSIM as a loss function, use the corresponding PyTorch module:
import torch
from piq import FSIMLoss
loss = FSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
Fréchet Inception Distance (FID)
Use the FID class to compute the FID score from image features pre-extracted from a feature extractor network:
import torch
from piq import FID
fid_metric = FID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
fid: torch.Tensor = fid_metric(prediction_feats, target_feats)
If image features are not available, extract them using the _compute_feats method of the FID class. Please note that _compute_feats consumes a data loader of a predefined format.
import torch
from torch.utils.data import DataLoader
from piq import FID
# Placeholders: supply DataLoader instances that yield batches in the format expected by piq.
first_dl, second_dl = DataLoader(), DataLoader()
fid_metric = FID()
first_feats = fid_metric._compute_feats(first_dl)
second_feats = fid_metric._compute_feats(second_dl)
fid: torch.Tensor = fid_metric(first_feats, second_feats)
Geometry Score (GS)
Use the GS class to compute the Geometry Score from image features pre-extracted from a feature extractor network. Computation is heavily CPU-dependent; adjust the num_workers parameter according to your system configuration:
import torch
from piq import GS
gs_metric = GS(sample_size=64, num_iters=100, i_max=100, num_workers=4)
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
gs: torch.Tensor = gs_metric(prediction_feats, target_feats)
The GS metric requires the gudhi library, which is not installed by default. If you use conda, run conda install -c conda-forge gudhi; otherwise, follow the installation guide.
Gradient Magnitude Similarity Deviation (GMSD)
This is a port of the MATLAB version from the authors of the original paper. It can be used both as a measure and as a loss function; in either case, it should be minimized. GMSD values usually lie in the [0, 0.35] interval.
import torch
from piq import GMSDLoss
loss = GMSDLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
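Since GMSD can also be used as a measure, here is a minimal sketch of that usage, assuming the lower-case gmsd function mirrors the signature of the other measures in the library:
import torch
from piq import gmsd

prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
# Lower values indicate higher perceptual similarity.
gmsd_index: torch.Tensor = gmsd(prediction, target, data_range=1.)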
Inception Score (IS)
Use the inception_score function to compute IS from image features pre-extracted from a feature extractor network. Note that we follow the recommendations from the paper A Note on the Inception Score, which proposed a small modification to the original algorithm:
import torch
from piq import inception_score
prediction_feats = torch.rand(10000, 1024)
mean, variance = inception_score(prediction_feats, num_splits=10)
To compute the difference between IS for two sets of image features, use the IS class.
import torch
from piq import IS
is_metric = IS(distance='l1')
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
distance: torch.Tensor = is_metric(prediction_feats, target_feats)
Kernel Inception Distance (KID)
Use the KID class to compute the KID score from image features pre-extracted from a feature extractor network:
import torch
from piq import KID
kid_metric = KID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
kid: torch.Tensor = kid_metric(prediction_feats, target_feats)
If image features are not available, extract them using the _compute_feats method of the KID class. Please note that _compute_feats consumes a data loader of a predefined format.
import torch
from torch.utils.data import DataLoader
from piq import KID
# Placeholders: supply DataLoader instances that yield batches in the format expected by piq.
first_dl, second_dl = DataLoader(), DataLoader()
kid_metric = KID()
first_feats = kid_metric._compute_feats(first_dl)
second_feats = kid_metric._compute_feats(second_dl)
kid: torch.Tensor = kid_metric(first_feats, second_feats)
Multi-Scale Intrinsic Distance (MSID)
Use the MSID class to compute the MSID score from image features pre-extracted from a feature extractor network:
import torch
from piq import MSID
msid_metric = MSID()
prediction_feats = torch.rand(10000, 1024)
target_feats = torch.rand(10000, 1024)
msid: torch.Tensor = msid_metric(prediction_feats, target_feats)
If image features are not available, extract them using the _compute_feats method of the MSID class. Please note that _compute_feats consumes a data loader of a predefined format.
import torch
from torch.utils.data import DataLoader
from piq import MSID
# Placeholders: supply DataLoader instances that yield batches in the format expected by piq.
first_dl, second_dl = DataLoader(), DataLoader()
msid_metric = MSID()
first_feats = msid_metric._compute_feats(first_dl)
second_feats = msid_metric._compute_feats(second_dl)
msid: torch.Tensor = msid_metric(first_feats, second_feats)
Multi-Scale Structural Similarity (MS-SSIM)
To compute the MS-SSIM index as a measure, use the lower-case function from the library:
import torch
from piq import multi_scale_ssim
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ms_ssim_index: torch.Tensor = multi_scale_ssim(prediction, target, data_range=1.)
In order to use MS-SSIM as a loss function, use the corresponding PyTorch module:
import torch
from piq import MultiScaleSSIMLoss
loss = MultiScaleSSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
MultiScale GMSD (MS-GMSD)
It can be used both as a measure and as a loss function; in either case, it should be minimized.
By default, scale weights are initialized with the values from the paper. You can change them by passing a list of 4 values to the scale_weights argument during initialization. Both GMSD and MS-GMSD are computed for greyscale images, but to take contrast changes into account the authors proposed to also add a chromatic component. Use the chromatic flag to get the MS-GMSDc version of the loss:
import torch
from piq import MultiScaleGMSDLoss
loss = MultiScaleGMSDLoss(chromatic=True, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
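To override the default scale weights mentioned above, pass a list of 4 values at initialization. The particular weights below are illustrative only, not the values from the paper:
import torch
from piq import MultiScaleGMSDLoss

# Illustrative weights; by default the values from the paper are used.
loss = MultiScaleGMSDLoss(scale_weights=[0.3, 0.3, 0.2, 0.2], data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()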
Peak Signal-to-Noise Ratio (PSNR)
To compute PSNR as a measure, use the lower-case function from the library.
By default it computes the average PSNR if more than one image is included in the batch.
You can specify other reduction methods with the reduction flag.
import torch
from piq import psnr
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
psnr_mean = psnr(prediction, target, data_range=1., reduction='mean')
psnr_per_image = psnr(prediction, target, data_range=1., reduction='none')
Note: Colour images are first converted to YCbCr format and only the luminance component is considered.
Structural Similarity (SSIM)
To compute the SSIM index as a measure, use the lower-case function from the library:
import torch
from piq import ssim
from typing import Union, Tuple
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
ssim_index: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = ssim(prediction, target, data_range=1.)
In order to use SSIM as a loss function, use the corresponding PyTorch module:
import torch
from piq import SSIMLoss
loss = SSIMLoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
Total Variation (TV)
To compute TV as a measure, use the lower-case function from the library:
import torch
from piq import total_variation
data = torch.rand(3, 3, 256, 256)
tv: torch.Tensor = total_variation(data)
In order to use TV as a loss function, use the corresponding PyTorch module:
import torch
from piq import TVLoss
loss = TVLoss()
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
output: torch.Tensor = loss(prediction)
output.backward()
Visual Information Fidelity (VIF)
To compute VIF as a measure, use the lower-case function from the library:
import torch
from piq import vif_p
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
vif: torch.Tensor = vif_p(prediction, target, data_range=1.)
In order to use VIF as a loss function, use the corresponding PyTorch module:
import torch
from piq import VIFLoss
loss = VIFLoss(sigma_n_sq=2.0, data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
Note that VIFLoss returns the 1 - VIF value.
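If the VIF value itself is needed, it can be recovered from the loss output; a minimal sketch:
import torch
from piq import VIFLoss

loss = VIFLoss(sigma_n_sq=2.0, data_range=1.)
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
# Since the loss is 1 - VIF, this recovers the measure value.
vif_value = 1. - output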
Visual Saliency-induced Index (VSI)
To compute VSI score as a measure, use lower case function from the library:
import torch
from piq import vsi
prediction = torch.rand(3, 3, 256, 256)
target = torch.rand(3, 3, 256, 256)
vsi_index: torch.Tensor = vsi(prediction, target, data_range=1.)
In order to use VSI as a loss function, use corresponding PyTorch module:
import torch
from piq import VSILoss
loss = VSILoss(data_range=1.)
prediction = torch.rand(3, 3, 256, 256, requires_grad=True)
target = torch.rand(3, 3, 256, 256)
output: torch.Tensor = loss(prediction, target)
output.backward()
PyTorch Image Quality (formerly PhotoSynthesis.Metrics) helps you concentrate on your experiments without boilerplate code. The library contains a set of measures and metrics that is constantly being extended. For measures/metrics that can be used as loss functions, corresponding PyTorch modules are implemented.
$ pip install piq
If you want to use the latest features straight from the master branch, clone the repo:
$ git clone https://github.com/photosynthesis-team/piq.git
See the open issues for a list of proposed features and known issues.
We appreciate all contributions. If you plan to:
- contribute back bug-fixes, please do so without any further discussion
- close one of the open issues, please do so if no one has been assigned to it
- contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us
Please see the contribution guide for more information.
Sergey Kastryulin - @snk4tr - [email protected]
Project Link: https://github.com/photosynthesis-team/piq
PhotoSynthesis Team: https://github.com/photosynthesis-team
- Pavel Parunin - @PavelParunin - idea proposal and development
- Djamil Zakirov - @zakajd - development
- Denis Prokopenko - @denproc - development