Skip to content

Commit

Permalink
Adds unit test for columnvector function
Browse files Browse the repository at this point in the history
Summary: Adds unit test to the test_processing.py for columnvector function from transform.py

Reviewed By: igfox

Differential Revision: D31247953

fbshipit-source-id: 8e6eee0fecf3dfb0bff8fb3d168e15f002c0acf3
  • Loading branch information
Yunus Emre authored and facebook-github-bot committed Sep 29, 2021
1 parent 5f0b21e commit 57f27db
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion reagent/test/preprocessing/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.testing as npt
import six
import torch
from reagent.preprocessing import identify_types, normalization
from reagent.preprocessing import identify_types, normalization, transforms
from reagent.preprocessing.identify_types import BOXCOX, CONTINUOUS, ENUM
from reagent.preprocessing.normalization import (
MISSING_VALUE,
Expand Down Expand Up @@ -363,3 +363,46 @@ def test_type_override_quantile(self):
"_", probability_values, feature_type=identify_types.QUANTILE
)
self.assertEqual(parameter.feature_type, "QUANTILE")

def test_columnvector(self):
def format_input2output(test_keys, inp_form):
test_data = {}
for ky in test_keys:
test_data[ky] = inp_form
test_instance = transforms.ColumnVector(test_keys)
output_data = test_instance(test_data)
return output_data

test_values = range(0, 5)
test_keys = []
for k in test_values:
test_keys.append(str(k))

# Possible input formats: tuple, list, torch.Tensor
for n_len in [1, 3]:
test_input_forms = [
(np.ones((n_len, 1)), 0),
n_len * [1],
torch.tensor(np.ones((n_len, 1))),
]
for inp_form in test_input_forms:
output_data = format_input2output(test_keys, inp_form)
for ky in test_keys:
self.assertEqual(output_data[ky].shape[0], n_len)
self.assertEqual(output_data[ky].shape[1], 1)

# Input as in row format
test_data = {}
for ky in test_keys:
test_data[ky] = (np.ones((1, 3)), 0)
test_instance = transforms.ColumnVector(test_keys)
with self.assertRaisesRegex(AssertionError, "Invalid shape for key"):
output_data = test_instance(test_data)

# Input as unimplemented type (number)
test_data = {}
for ky in test_keys:
test_data[ky] = 1
test_instance = transforms.ColumnVector(test_keys)
with self.assertRaisesRegex(NotImplementedError, "value of type"):
output_data = test_instance(test_data)

0 comments on commit 57f27db

Please sign in to comment.