Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali Abid committed Aug 10, 2020
1 parent c90fc29 commit b8707fb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 64 deletions.
50 changes: 6 additions & 44 deletions test/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,60 +11,22 @@
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'


class TestSketchpad(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Sketchpad()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

class TestImage(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.Sketchpad()
inp = inputs.Image(shape=(20, 20))
array = inp.preprocess(BASE64_SKETCH)
self.assertEqual(array.shape, (1, 28, 28))


class TestWebcam(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Webcam()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_preprocessing(self):
inp = inputs.Webcam()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (224, 224, 3))
self.assertEqual(array.shape, (20, 20, 3))
inp2 = inputs.Image(shape=(20, 20), image_mode="L")
array2 = inp2.preprocess(BASE64_SKETCH)
self.assertEqual(array2.shape, (20, 20))


class TestTextbox(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Textbox()
path = BASE_INPUT_INTERFACE_JS_PATH.format(
inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_preprocessing(self):
inp = inputs.Textbox()
string = inp.preprocess(RAND_STRING)
self.assertEqual(string, RAND_STRING)


class TestImageUpload(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Image()
path = BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_preprocessing(self):
inp = inputs.Image()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (224, 224, 3))

def test_preprocessing(self):
inp = inputs.Image()
inp.image_height = 48
inp.image_width = 48
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (48, 48, 3))

if __name__ == '__main__':
unittest.main()
10 changes: 5 additions & 5 deletions test/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@

class TestInterface(unittest.TestCase):
def test_input_output_mapping(self):
io = gr.Interface(inputs='SketCHPad', outputs='TexT', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
io = gr.Interface(inputs='sketchpad', outputs='text', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Image)
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)

def test_input_interface_is_instance(self):
inp = gradio.inputs.Image()
io = gr.Interface(inputs=inp, outputs='teXT', fn=lambda x: x)
io = gr.Interface(inputs=inp, outputs='text', fn=lambda x: x)
self.assertEqual(io.input_interfaces[0], inp)

def test_output_interface_is_instance(self):
out = gradio.outputs.Label()
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
io = gr.Interface(inputs='sketchpad', outputs=out, fn=lambda x: x)
self.assertEqual(io.output_interfaces[0], out)

def test_prediction(self):
def model(x):
return 2*x
io = gr.Interface(inputs='textbox', outputs='TEXT', fn=model)
io = gr.Interface(inputs='textbox', outputs='text', fn=model)
self.assertEqual(io.predict[0](11), 22)


Expand Down
15 changes: 0 additions & 15 deletions test/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'

class TestLabel(unittest.TestCase):
def test_path_exists(self):
out = outputs.Label()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
Expand Down Expand Up @@ -52,11 +47,6 @@ def test_postprocessing_int(self):


class TestTextbox(unittest.TestCase):
def test_path_exists(self):
out = outputs.Textbox()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_postprocessing(self):
string = 'happy'
out = outputs.Textbox()
Expand All @@ -65,11 +55,6 @@ def test_postprocessing(self):


class TestImage(unittest.TestCase):
def test_path_exists(self):
out = outputs.Image()
path = BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__qualname__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

def test_postprocessing(self):
string = BASE64_IMG
out = outputs.Textbox()
Expand Down

0 comments on commit b8707fb

Please sign in to comment.