test state

This commit is contained in:
Abubakar Abid 2022-03-04 00:24:17 -05:00
parent a38a428072
commit fd034ee29f
2 changed files with 26 additions and 10 deletions

View File

@ -196,16 +196,16 @@ class Interface(Launchable):
state.default = default
self.state_default = state.default
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
state_return_index = [
isinstance(i, o_State) for i in self.output_components
].index(True)
self.state_return_index = state_return_index
else:
raise ValueError(
"For a stateful interface, there must be exactly one State"
" input component and one State output component."
)
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
state_return_index = [
isinstance(i, o_State) for i in self.output_components
].index(True)
self.state_return_index = state_return_index
else:
raise ValueError(
"For a stateful interface, there must be exactly one State"
" input component and one State output component."
)
if (
interpretation is None

View File

@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output)
def test_state(self):
def predict(input, history=""):
history += input
return history, history
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": ["test", None]})
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post("/api/predict/", json={"data": ["test", None]})
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])
def test_queue_push_route(self):
queueing.push = mock.MagicMock(return_value=(None, None))
response = self.client.post(