Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Jul 18, 2024
1 parent 15296af commit dee3863
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 248 deletions.
2 changes: 1 addition & 1 deletion src/fairseq2/datasets/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def to_batch(example: Dict[str, Any]) -> SequenceBatch:
pipeline = builder.map(to_batch).and_return()

return DataPipelineReader[SequenceBatch](
pipeline, gang, drop_remainder=drop_remainder, sync_batches=True
pipeline, gang, drop_remainder=drop_remainder, sync_batches=sync_batches
)

def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder:
Expand Down
41 changes: 39 additions & 2 deletions src/fairseq2/metrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union
from __future__ import annotations

from typing import Iterable, Optional, Union

import torch
from torch import Tensor
from torcheval.metrics import Max as MaxBase
from torcheval.metrics import Mean as MeanBase
from torcheval.metrics import Metric
from torcheval.metrics import Min as MinBase
from torcheval.metrics import Sum as SumBase
from typing_extensions import Self

from fairseq2.typing import override
from fairseq2.typing import Device, override


class Min(MinBase):
Expand Down Expand Up @@ -77,3 +80,37 @@ def update(
super().update(input_, weight=weight)

return self


class MaxSum(Metric[Tensor]):
    """Keep a locally accumulated sum and reduce with ``max`` across metrics.

    Every :meth:`update` call adds its input to a local running total. When
    several instances are merged, the largest of the running totals is kept
    instead of adding them together.
    """

    # Running total; registered as metric state in ``__init__``.
    sum_: Tensor

    def __init__(self, *, device: Optional[Device] = None) -> None:
        """
        :param device:
            The device on which the running total is allocated; forwarded
            unchanged to the base metric.
        """
        super().__init__(device=device)

        # The total starts out as a zero-dimensional 64-bit integer zero.
        self._add_state(
            "sum_", torch.zeros((), device=device, dtype=torch.int64)
        )

    @override
    @torch.inference_mode()
    def update(self, input_: Union[int, Tensor]) -> Self:
        """Add ``input_`` to the running total in place.

        :param input_:
            The value to add.

        :returns:
            This metric, so that calls can be chained.
        """
        # NOTE(review): the addition is in place on a zero-dimensional
        # tensor, so ``input_`` presumably has to be a scalar - confirm.
        self.sum_ += input_

        return self

    @override
    @torch.inference_mode()
    def compute(self) -> Tensor:
        """Return the running total tensor."""
        return self.sum_

    @override
    @torch.inference_mode()
    def merge_state(self, metrics: Iterable[MaxSum]) -> Self:
        """Fold ``metrics`` into this one by keeping the largest total.

        :param metrics:
            The metrics whose totals are compared against this one.

        :returns:
            This metric, so that calls can be chained.
        """
        for other in metrics:
            # Move the peer total to this metric's device before comparing.
            peer_sum = other.sum_.to(self.device)

            self.sum_ = torch.maximum(self.sum_, peer_sum)

        return self
24 changes: 20 additions & 4 deletions src/fairseq2/metrics/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import json
import math
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
Expand Down Expand Up @@ -36,7 +37,10 @@

def format_as_int(value: Any, *, postfix: Optional[str] = None) -> str:
"""Format metric ``value`` as integer."""
i = int(value)
try:
i = int(value)
except ValueError:
return f"{value}"

s = "<1" if i == 0 and isinstance(value, float) else f"{i:,}"

Expand All @@ -52,7 +56,10 @@ def format_as_int(value: Any, *, postfix: Optional[str] = None) -> str:

def format_as_float(value: Any, *, postfix: Optional[str] = None) -> str:
"""Format metric ``value`` as float."""
s = f"{float(value):g}"
try:
s = f"{float(value):g}"
except ValueError:
return f"{value}"

if postfix:
s += postfix
Expand All @@ -67,7 +74,13 @@ def format_as_byte_size(value: Any) -> str:
"""Format metric ``value`` in byte units."""
unit_idx = 0

size = float(value)
try:
size = float(value)
except ValueError:
return f"{value}"

if not math.isfinite(size) or size <= 0.0:
return "0 B"

while size >= 1024:
size /= 1024
Expand Down Expand Up @@ -336,7 +349,10 @@ def sanitize(value: Any, formatter: _MetricFormatter) -> Any:
value = value.item()

if formatter.fn is format_as_int:
value = int(value)
try:
value = int(value)
except ValueError:
pass

return value

Expand Down
Loading

0 comments on commit dee3863

Please sign in to comment.