Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (4/16)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent 7996b89 commit ce8b2f7
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 41 deletions.
12 changes: 3 additions & 9 deletions tests/metrics/window/test_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,8 @@ def _test_mean_squared_error_class_with_input(
-1, int(torch.numel(target) / (NUM_TOTAL_UPDATES * BATCH_SIZE))
).squeeze()

target_window = target[
(-1) * max_num_updates :,
]
input_window = input[
(-1) * max_num_updates :,
]
target_window = target[(-1) * max_num_updates :,]
input_window = input[(-1) * max_num_updates :,]

target_window_np = target_window.reshape(
-1, int(torch.numel(target_window) / (max_num_updates * BATCH_SIZE))
Expand All @@ -54,9 +50,7 @@ def _test_mean_squared_error_class_with_input(
sample_weight_np = sample_weight.reshape(
-1, int(torch.numel(sample_weight) / (NUM_TOTAL_UPDATES * BATCH_SIZE))
).squeeze()
sample_weight_window = sample_weight[
(-1) * max_num_updates :,
]
sample_weight_window = sample_weight[(-1) * max_num_updates :,]
sample_weight_window_np = sample_weight_window.reshape(
-1,
int(torch.numel(sample_weight_window) / (max_num_updates * BATCH_SIZE)),
Expand Down
2 changes: 1 addition & 1 deletion torcheval/metrics/functional/text/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _bleu_score_update(
matches_by_order = torch.zeros(n_gram, device=device)
possible_matches_by_order = torch.zeros(n_gram, device=device)

for (candidate, references) in zip(input_, target_):
for candidate, references in zip(input_, target_):
candidate_tokenized = candidate.split()
references_tokenized = [ref.split() for ref in references]

Expand Down
6 changes: 2 additions & 4 deletions torcheval/metrics/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,15 @@ def _sync_metric_object(
local_metric_data: Metric,
process_group: dist.ProcessGroup,
world_size: int,
) -> List[Metric]:
...
) -> List[Metric]: ...


@overload
def _sync_metric_object(
local_metric_data: MutableMapping[str, Metric],
process_group: dist.ProcessGroup,
world_size: int,
) -> List[MutableMapping[str, Metric]]:
...
) -> List[MutableMapping[str, Metric]]: ...


def _apply_device_to_tensor_states(
Expand Down
12 changes: 6 additions & 6 deletions torcheval/metrics/window/click_through_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,12 @@ def merge_state(
self.click_total += metric.click_total.to(self.device)
self.weight_total += metric.weight_total.to(self.device)
cur_size = min(metric.total_updates, metric.max_num_updates)
self.windowed_click_total[
:, idx : idx + cur_size
] = metric.windowed_click_total[:, :cur_size]
self.windowed_weight_total[
:, idx : idx + cur_size
] = metric.windowed_weight_total[:, :cur_size]
self.windowed_click_total[:, idx : idx + cur_size] = (
metric.windowed_click_total[:, :cur_size]
)
self.windowed_weight_total[:, idx : idx + cur_size] = (
metric.windowed_weight_total[:, :cur_size]
)
idx += cur_size
self.total_updates += metric.total_updates

Expand Down
12 changes: 6 additions & 6 deletions torcheval/metrics/window/mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,12 @@ def merge_state(
self.sum_squared_error += metric.sum_squared_error.to(self.device)
self.sum_weight += metric.sum_weight.to(self.device)
cur_size = min(metric.total_updates, metric.max_num_updates)
self.windowed_sum_squared_error[
:, idx : idx + cur_size
] = metric.windowed_sum_squared_error[:, :cur_size]
self.windowed_sum_weight[
:, idx : idx + cur_size
] = metric.windowed_sum_weight[:, :cur_size]
self.windowed_sum_squared_error[:, idx : idx + cur_size] = (
metric.windowed_sum_squared_error[:, :cur_size]
)
self.windowed_sum_weight[:, idx : idx + cur_size] = (
metric.windowed_sum_weight[:, :cur_size]
)
idx += cur_size
self.total_updates += metric.total_updates

Expand Down
18 changes: 9 additions & 9 deletions torcheval/metrics/window/normalized_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,15 @@ def merge_state(
self.num_examples += metric.num_examples.to(self.device)
self.num_positive += metric.num_positive.to(self.device)
cur_size = min(metric.total_updates, metric.max_num_updates)
self.windowed_total_entropy[
:, idx : idx + cur_size
] = metric.windowed_total_entropy[:, :cur_size]
self.windowed_num_examples[
:, idx : idx + cur_size
] = metric.windowed_num_examples[:, :cur_size]
self.windowed_num_positive[
:, idx : idx + cur_size
] = metric.windowed_num_positive[:, :cur_size]
self.windowed_total_entropy[:, idx : idx + cur_size] = (
metric.windowed_total_entropy[:, :cur_size]
)
self.windowed_num_examples[:, idx : idx + cur_size] = (
metric.windowed_num_examples[:, :cur_size]
)
self.windowed_num_positive[:, idx : idx + cur_size] = (
metric.windowed_num_positive[:, :cur_size]
)
idx += cur_size
self.total_updates += metric.total_updates

Expand Down
12 changes: 6 additions & 6 deletions torcheval/metrics/window/weighted_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,12 @@ def merge_state(
self.weighted_input_sum += metric.weighted_input_sum.to(self.device)
self.weighted_target_sum += metric.weighted_target_sum.to(self.device)
cur_size = min(metric.total_updates, metric.max_num_updates)
self.windowed_weighted_input_sum[
:, idx : idx + cur_size
] = metric.windowed_weighted_input_sum[:, :cur_size]
self.windowed_weighted_target_sum[
:, idx : idx + cur_size
] = metric.windowed_weighted_target_sum[:, :cur_size]
self.windowed_weighted_input_sum[:, idx : idx + cur_size] = (
metric.windowed_weighted_input_sum[:, :cur_size]
)
self.windowed_weighted_target_sum[:, idx : idx + cur_size] = (
metric.windowed_weighted_target_sum[:, :cur_size]
)
idx += cur_size
self.total_updates += metric.total_updates
self.next_inserted = idx
Expand Down

0 comments on commit ce8b2f7

Please sign in to comment.