This library is currently in Alpha and does not yet have a stable release. The API may change and may not be backward compatible. If you have suggestions for improvements, please open a GitHub issue. We'd love to hear your feedback.
TorchEval is a library that contains a rich collection of performant PyTorch model metrics, a simple interface to create new metrics, a toolkit to facilitate metric computation in distributed training, and tools for PyTorch model evaluations.
Requires Python >= 3.7 and PyTorch >= 1.11
From pip:
pip install torcheval
For the nightly build:
pip install --pre torcheval-nightly
From source:
git clone https://github.com/pytorch-labs/torcheval
cd torcheval
pip install -r requirements.txt
python setup.py install
cd torcheval
python examples/simple_example.py
TorchEval can be run on CPU, GPU, and in multi-GPU or multi-node setups.
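As a rough illustration of how each process might choose among these targets, the sketch below isolates the device-selection condition used in the multi-device example in this README. The helper name `pick_device` and its parameters are hypothetical and not part of TorchEval; the availability flag and GPU count are passed in as plain values so the logic can be read and tested without PyTorch installed.

```python
def pick_device(local_rank: int, world_size: int, cuda_available: bool, gpu_count: int) -> str:
    """Return a device string for one process.

    Hypothetical helper for illustration only; not part of TorchEval.
    """
    # Use a GPU only when CUDA is available and there are at least as many
    # visible GPUs as processes, so every process can have its own device.
    if cuda_available and gpu_count >= world_size:
        return f"cuda:{local_rank}"
    # Otherwise every process falls back to the CPU.
    return "cpu"


print(pick_device(local_rank=1, world_size=2, cuda_available=True, gpu_count=2))   # cuda:1
print(pick_device(local_rank=1, world_size=4, cuda_available=True, gpu_count=2))   # cpu
print(pick_device(local_rank=0, world_size=1, cuda_available=False, gpu_count=0))  # cpu
```

In a real job the two last arguments would come from `torch.cuda.is_available()` and `torch.cuda.device_count()`. Note that `WORLD_SIZE` counts processes across all nodes while the GPU count is per machine, so this condition is best read as a simple single-node heuristic.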
For multi-device usage:
import os

import torch
from torcheval.metrics.toolkit import sync_and_compute
from torcheval.metrics import MulticlassAccuracy

# These environment variables are expected to be set by the distributed launcher.
local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

device = torch.device(
    f"cuda:{local_rank}"
    if torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    else "cpu"
)

metric = MulticlassAccuracy(device=device)
num_epochs, num_batches = 4, 8

for epoch in range(num_epochs):
    for i in range(num_batches):
        input = torch.randint(high=5, size=(10,), device=device)
        target = torch.randint(high=5, size=(10,), device=device)

        # metric.update() updates the metric state with new data
        metric.update(input, target)

        # metric.compute() returns metric value from all seen data on the local process.
        local_compute_result = metric.compute()

        # sync_and_compute(metric) returns metric value from all seen data on all processes.
        # It gives the same result as ``metric.compute()`` if it's run on a single process.
        global_compute_result = sync_and_compute(metric)

        # The final result is collected by rank 0
        if global_rank == 0:
            print(global_compute_result)

    # metric.reset() cleans up all seen data
    metric.reset()
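To make the difference between the local and the synced result more concrete, here is a minimal pure-Python sketch of the idea, assuming an accuracy metric whose state is just two counters. It is not TorchEval's implementation: `AccuracyState` and `merge_and_compute` are hypothetical names, and the "processes" are simulated as two objects in one script.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class AccuracyState:
    """Running counts held by one process (hypothetical, for illustration only)."""

    num_correct: int = 0
    num_total: int = 0

    def update(self, predictions: List[int], targets: List[int]) -> None:
        # Accumulate counts from one batch.
        self.num_correct += sum(p == t for p, t in zip(predictions, targets))
        self.num_total += len(targets)

    def compute(self) -> float:
        # Accuracy over everything this process has seen since the last reset.
        return self.num_correct / self.num_total if self.num_total else 0.0

    def reset(self) -> None:
        # Clear the counts so the next epoch starts from scratch.
        self.num_correct = 0
        self.num_total = 0


def merge_and_compute(states: Iterable[AccuracyState]) -> float:
    """Combine per-process counts first, then compute a single global value."""
    merged = AccuracyState()
    for state in states:
        merged.num_correct += state.num_correct
        merged.num_total += state.num_total
    return merged.compute()


rank0, rank1 = AccuracyState(), AccuracyState()
rank0.update([0, 1, 2, 3], [0, 1, 2, 0])  # 3 of 4 correct
rank1.update([1, 1], [0, 0])              # 0 of 2 correct

print(rank0.compute())                    # 0.75
print(rank1.compute())                    # 0.0
print(merge_and_compute([rank0, rank1]))  # 0.5
```

The global value (0.5) differs from the mean of the two local values (0.375) because the counts are merged before the ratio is taken, which is why a synced result should be computed from combined state rather than by averaging per-process results.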
See the example directory for more examples.
We welcome PRs! See the CONTRIBUTING file.
TorchEval is BSD licensed, as found in the LICENSE file.