mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
test state
This commit is contained in:
parent
a38a428072
commit
fd034ee29f
@ -196,16 +196,16 @@ class Interface(Launchable):
|
||||
state.default = default
|
||||
self.state_default = state.default
|
||||
|
||||
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
|
||||
state_return_index = [
|
||||
isinstance(i, o_State) for i in self.output_components
|
||||
].index(True)
|
||||
self.state_return_index = state_return_index
|
||||
else:
|
||||
raise ValueError(
|
||||
"For a stateful interface, there must be exactly one State"
|
||||
" input component and one State output component."
|
||||
)
|
||||
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
|
||||
state_return_index = [
|
||||
isinstance(i, o_State) for i in self.output_components
|
||||
].index(True)
|
||||
self.state_return_index = state_return_index
|
||||
else:
|
||||
raise ValueError(
|
||||
"For a stateful interface, there must be exactly one State"
|
||||
" input component and one State output component."
|
||||
)
|
||||
|
||||
if (
|
||||
interpretation is None
|
||||
|
@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
def test_state(self):
|
||||
def predict(input, history=""):
|
||||
history += input
|
||||
return history, history
|
||||
|
||||
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
|
||||
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/predict/", json={"data": ["test", None]})
|
||||
output = dict(response.json())
|
||||
print("output", output)
|
||||
self.assertEqual(output["data"], ["test", None])
|
||||
response = client.post("/api/predict/", json={"data": ["test", None]})
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["testtest", None])
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.push = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post(
|
||||
|
Loading…
x
Reference in New Issue
Block a user