formatting

This commit is contained in:
Abubakar Abid 2022-03-23 15:50:10 -07:00
parent fa27aa5876
commit 63d18ac02b
5 changed files with 39 additions and 36 deletions

View File

@ -18,8 +18,8 @@ from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin from mdit_py_plugins.footnote import footnote_plugin
from gradio import interpretation, utils from gradio import interpretation, utils
from gradio.components import Component, get_component_instance, Markdown, Button from gradio.blocks import BlockContext, Column, Row
from gradio.blocks import BlockContext, Row, Column from gradio.components import Button, Component, Markdown, get_component_instance
from gradio.external import load_from_pipeline, load_interface # type: ignore from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import State as i_State # type: ignore from gradio.inputs import State as i_State # type: ignore

View File

@ -8,7 +8,6 @@ import huggingface_hub
import gradio as gr import gradio as gr
from gradio import flagging from gradio import flagging
# class TestDefaultFlagging(unittest.TestCase): # class TestDefaultFlagging(unittest.TestCase):
# def test_default_flagging_callback(self): # def test_default_flagging_callback(self):
# with tempfile.TemporaryDirectory() as tmpdirname: # with tempfile.TemporaryDirectory() as tmpdirname:

View File

@ -194,7 +194,7 @@ class TestSlider(unittest.TestCase):
"default": 15, "default": 15,
"name": "slider", "name": "slider",
"label": "Slide Your Input", "label": "Slide Your Input",
"css": {} "css": {},
}, },
) )
@ -530,7 +530,7 @@ class TestAudio(unittest.TestCase):
"source": "upload", "source": "upload",
"name": "audio", "name": "audio",
"label": "Upload Your Audio", "label": "Upload Your Audio",
"css": {} "css": {},
}, },
) )
self.assertIsNone(audio_input.preprocess(None)) self.assertIsNone(audio_input.preprocess(None))
@ -589,7 +589,7 @@ class TestFile(unittest.TestCase):
"file_count": "single", "file_count": "single",
"name": "file", "name": "file",
"label": "Upload Your File", "label": "Upload Your File",
"css": {} "css": {},
}, },
) )
self.assertIsNone(file_input.preprocess(None)) self.assertIsNone(file_input.preprocess(None))

View File

@ -1,8 +1,8 @@
import json
import os import os
import tempfile import tempfile
import unittest import unittest
import json
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -16,10 +16,8 @@ class TestTextbox(unittest.TestCase):
def test_in_interface(self): def test_in_interface(self):
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox()) iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"]) self.assertEqual(iface.process(["Hello"])[0], ["o"])
iface = gr.Interface( iface = gr.Interface(lambda x: x / 2, "number", gr.outputs.Textbox())
lambda x: x / 2, "number", gr.outputs.Textbox() self.assertEqual(iface.process([10])[0], ["5.0"])
)
self.assertEqual(iface.process([10])[0], ['5.0'])
class TestLabel(unittest.TestCase): class TestLabel(unittest.TestCase):
@ -185,7 +183,7 @@ class TestHighlightedText(unittest.TestCase):
"name": "highlightedtext", "name": "highlightedtext",
"label": None, "label": None,
"show_legend": False, "show_legend": False,
"css": {} "css": {},
}, },
) )
ht = {"pos": "Hello ", "neg": "World"} ht = {"pos": "Hello ", "neg": "World"}
@ -232,10 +230,8 @@ class TestAudio(unittest.TestCase):
) )
) )
self.assertEqual( self.assertEqual(
audio_output.get_template_context(), {"name": "audio", audio_output.get_template_context(),
"label": None, {"name": "audio", "label": None, "source": "upload", "css": {}},
"source": "upload",
"css": {}}
) )
self.assertTrue( self.assertTrue(
audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav") audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav")
@ -363,7 +359,6 @@ class TestDataframe(unittest.TestCase):
"col_width": None, "col_width": None,
"default": [[None, None, None], [None, None, None], [None, None, None]], "default": [[None, None, None], [None, None, None], [None, None, None]],
"name": "dataframe", "name": "dataframe",
}, },
) )
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -373,14 +368,22 @@ class TestDataframe(unittest.TestCase):
to_save = dataframe_output.save_flagged( to_save = dataframe_output.save_flagged(
tmpdirname, "dataframe_output", output, None tmpdirname, "dataframe_output", output, None
) )
self.assertEqual(to_save, json.dumps({ self.assertEqual(
to_save,
json.dumps(
{
"headers": ["num", "prime"], "headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]] "data": [[2, True], [3, True], [4, False]],
})) }
),
)
self.assertEqual( self.assertEqual(
dataframe_output.restore_flagged(tmpdirname, to_save, None), dataframe_output.restore_flagged(tmpdirname, to_save, None),
{"headers": ["num", "prime"], {
"data": [[2, True], [3, True], [4, False]]}) "headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
},
)
def test_in_interface(self): def test_in_interface(self):
def check_odd(array): def check_odd(array):
@ -417,12 +420,18 @@ class TestCarousel(unittest.TestCase):
carousel_output.get_template_context(), carousel_output.get_template_context(),
{ {
"components": [ "components": [
{"name": "textbox", "label": None, "default": "", "lines": 1, {
"css": {}, 'placeholder': None} "name": "textbox",
"label": None,
"default": "",
"lines": 1,
"css": {},
"placeholder": None,
}
], ],
"name": "carousel", "name": "carousel",
"label": "Disease", "label": "Disease",
"css": {} "css": {},
}, },
) )
output = carousel_output.postprocess(["Hello World", "Bye World"]) output = carousel_output.postprocess(["Hello World", "Bye World"])
@ -473,13 +482,7 @@ class TestTimeseries(unittest.TestCase):
timeseries_output = gr.outputs.Timeseries(label="Disease") timeseries_output = gr.outputs.Timeseries(label="Disease")
self.assertEqual( self.assertEqual(
timeseries_output.get_template_context(), timeseries_output.get_template_context(),
{ {"x": None, "y": None, "name": "timeseries", "label": "Disease", "css": {}},
"x": None,
"y": None,
"name": "timeseries",
"label": "Disease",
"css": {}
},
) )
data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]} data = {"Name": ["Tom", "nick", "krish", "jack"], "Age": [20, 21, 19, 18]}
df = pd.DataFrame(data) df = pd.DataFrame(data)

View File

@ -37,8 +37,9 @@ class TestRoutes(unittest.TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_predict_route(self): def test_predict_route(self):
response = self.client.post("/api/predict/", response = self.client.post(
json={"data": ["test"], "fn_index": 0}) "/api/predict/", json={"data": ["test"], "fn_index": 0}
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
output = dict(response.json()) output = dict(response.json())
self.assertEqual(output["data"], ["testtest"]) self.assertEqual(output["data"], ["testtest"])