Improvements to State (#2100)

* state

* state fix

* variable -> state

* fix

* added state tests

* formatting

* fix test

* formatting

* fix test

* added tests for bakcward compatibility

* formatting

* config fix

* additional doc

* doc fix

* formatting
This commit is contained in:
Abubakar Abid 2022-08-29 09:53:05 -07:00 committed by GitHub
parent 1ad587834a
commit 4d58ae79b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 113 additions and 46 deletions

View File

@ -21,7 +21,7 @@ with demo:
flip_btn = gr.Button("Flip Card").style(full_width=True)
with gr.Column(visible=False) as answer_col:
back = gr.Textbox(label="Answer")
selected_card = gr.Variable()
selected_card = gr.State()
with gr.Row():
correct_btn = gr.Button(
"Correct",
@ -29,7 +29,7 @@ with demo:
incorrect_btn = gr.Button("Incorrect").style(full_width=True)
with gr.TabItem("Results"):
results = gr.Variable(value={})
results = gr.State(value={})
correct_field = gr.Markdown("# Correct: 0")
incorrect_field = gr.Markdown("# Incorrect: 0")
gr.Markdown("Card Statistics: ")

View File

@ -5,11 +5,11 @@ demo = gr.Blocks(css="#btn {color: red}")
with demo:
default_json = {"a": "a"}
num = gr.Variable(value=0)
num = gr.State(value=0)
squared = gr.Number(value=0)
btn = gr.Button("Next Square", elem_id="btn").style(rounded=False)
stats = gr.Variable(value=default_json)
stats = gr.State(value=default_json)
table = gr.JSON()
def increase(var, stats_history):

View File

@ -42,8 +42,8 @@ with gr.Blocks() as Dataframe_demo:
with gr.Blocks() as Timeseries_demo:
gr.Timeseries()
with gr.Blocks() as Variable_demo:
gr.Variable()
with gr.Blocks() as State_demo:
gr.State()
with gr.Blocks() as Button_demo:
gr.Button()

View File

@ -4,7 +4,7 @@ import random
secret_word = "gradio"
with gr.Blocks() as demo:
used_letters_var = gr.Variable([])
used_letters_var = gr.State([])
with gr.Row() as row:
with gr.Column():
input_letter = gr.Textbox(label="Enter letter")

View File

@ -17,7 +17,7 @@ from constants import (
demo = gr.Interface(
lambda x: x,
lambda *args: args[0],
inputs=[
gr.Textbox(value=lambda: datetime.now(), label="Current Time"),
gr.Number(value=lambda: random.random(), label="Ranom Percentage"),
@ -60,7 +60,7 @@ demo = gr.Interface(
)
),
gr.Timeseries(value=lambda: os.path.join(file_dir, "time.csv")),
gr.Variable(value=lambda: random.choice(string.ascii_lowercase)),
gr.State(value=lambda: random.choice(string.ascii_lowercase)),
gr.Button(value=lambda: random.choice(["Run", "Go", "predict"])),
gr.ColorPicker(value=lambda: random.choice(["#000000", "#ff0000", "#0000FF"])),
gr.Label(value=lambda: random.choice(["Pedestrian", "Car", "Cyclist"])),
@ -91,7 +91,9 @@ demo = gr.Interface(
gr.Plot(value=random_plot),
gr.Markdown(value=lambda: f"### {random.choice(['Hello', 'Hi', 'Goodbye!'])}"),
],
outputs=None,
outputs=[
gr.State(value=lambda: random.choice(string.ascii_lowercase))
],
)
if __name__ == "__main__":

View File

@ -4,7 +4,7 @@ import numpy as np
with gr.Blocks() as demo:
inp = gr.Audio(source="microphone")
out = gr.Audio()
stream = gr.Variable()
stream = gr.State()
def add_to_stream(audio, instream):
if audio is None:

View File

@ -35,6 +35,7 @@ from gradio.components import (
Plot,
Radio,
Slider,
State,
StatusTracker,
Textbox,
TimeSeries,

View File

@ -2541,10 +2541,10 @@ class Timeseries(Changeable, IOComponent, JSONSerializable):
@document()
class Variable(IOComponent, SimpleSerializable):
class State(IOComponent, SimpleSerializable):
"""
Special hidden component that stores session state across runs of the demo by the
same user. The value of the Variable is cleared when the user refreshes the page.
same user. The value of the State variable is cleared when the user refreshes the page.
Preprocessing: No preprocessing is performed
Postprocessing: No postprocessing is performed
@ -2570,6 +2570,16 @@ class Variable(IOComponent, SimpleSerializable):
return self
class Variable(State):
"""Variable was renamed to State. This class is kept for backwards compatibility."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_block_name(self):
return "state"
@document("click", "style")
class Button(Clickable, IOComponent, SimpleSerializable):
"""

View File

@ -427,7 +427,7 @@ class Timeseries(components.Timeseries):
super().__init__(x=x, y=y, label=label, optional=optional)
class State(components.Variable):
class State(components.State):
"""
Special hidden component that stores state across runs of the interface.
Input type: Any
@ -445,7 +445,7 @@ class State(components.Variable):
optional (bool): this parameter is ignored.
"""
warnings.warn(
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.Variable from gradio.components",
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
)
super().__init__(value=default, label=label)

View File

@ -27,8 +27,8 @@ from gradio.components import (
Interpretation,
IOComponent,
Markdown,
State,
StatusTracker,
Variable,
get_component_instance,
)
from gradio.documentation import document, set_documentation_group
@ -213,17 +213,30 @@ class Interface(Blocks):
else:
self.cache_examples = cache_examples or False
if "state" in inputs or "state" in outputs:
state_input_count = len([i for i in inputs if i == "state"])
state_output_count = len([o for o in outputs if o == "state"])
if state_input_count != 1 or state_output_count != 1:
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
default = utils.get_default_args(fn)[inputs.index("state")]
state_variable = Variable(value=default)
inputs[inputs.index("state")] = state_variable
outputs[outputs.index("state")] = state_variable
state_input_indexes = [
idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
]
state_output_indexes = [
idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
]
if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
pass
elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
else:
state_input_index = state_input_indexes[0]
state_output_index = state_output_indexes[0]
if inputs[state_input_index] == "state":
default = utils.get_default_args(fn)[state_input_index]
state_variable = State(value=default)
else:
state_variable = inputs[state_input_index]
inputs[state_input_index] = state_variable
outputs[state_output_index] = state_variable
if cache_examples:
warnings.warn(
@ -240,9 +253,7 @@ class Interface(Blocks):
]
for component in self.input_components + self.output_components:
if not (
isinstance(component, IOComponent) or isinstance(component, Variable)
):
if not (isinstance(component, IOComponent)):
raise ValueError(
f"{component} is not a valid input/output component for Interface."
)
@ -607,10 +618,10 @@ class Interface(Blocks):
if self.examples:
non_state_inputs = [
c for c in self.input_components if not isinstance(c, Variable)
c for c in self.input_components if not isinstance(c, State)
]
non_state_outputs = [
c for c in self.output_components if not isinstance(c, Variable)
c for c in self.output_components if not isinstance(c, State)
]
self.examples_handler = Examples(
examples=examples,

View File

@ -158,7 +158,7 @@ class Timeseries(components.Timeseries):
super().__init__(x=x, y=y, label=label)
class State(components.Variable):
class State(components.State):
"""
Special hidden component that stores state across runs of the interface.
Output type: Any
@ -170,7 +170,7 @@ class State(components.Variable):
label (str): component name in interface (not used).
"""
warnings.warn(
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
)
super().__init__(label=label)

View File

@ -198,7 +198,7 @@ def launch_counter() -> None:
pass
def get_default_args(func: Callable) -> Dict[str, Any]:
def get_default_args(func: Callable) -> List[Any]:
signature = inspect.signature(func)
return [
v.default if v.default is not inspect.Parameter.empty else None

View File

@ -23,4 +23,4 @@ $demo_chatbot_demo
Notice how the state persists across submits within each page, but if you load this demo in another tab (or refresh the page), the demos will not share chat history.
The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead.
The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. The `Interface` class only supports a single input and outputs state variable, though it can be a list with multiple elements. For more complex use cases, you can use Blocks, [which supports multiple `State` variables](/state_in_blocks/).

View File

@ -8,8 +8,8 @@ Global state in Blocks works the same as in Interface. Any variable created outs
Gradio supports session **state**, where data persists across multiple submits within a page session, in Blocks apps as well. To reiterate, session data is *not* shared between different users of your model. To store data in a session state, you need to do three things:
1. Create a `gr.Variable()` object. If there is a default value to this stateful object, pass that into the constructor.
2. In the event listener, put the `Variable` object as an input and output.
1. Create a `gr.State()` object. If there is a default value to this stateful object, pass that into the constructor.
2. In the event listener, put the `State` object as an input and output.
3. In the event listener function, add the variable to the input parameters and the return value.
Let's take a look at a game of hangman.
@ -19,11 +19,11 @@ $demo_hangman
Let's see how we do each of the 3 steps listed above in this game:
1. We store the used letters in `used_letters_var`. In the constructor of `Variable`, we set the initial value of this to `[]`, an empty list.
1. We store the used letters in `used_letters_var`. In the constructor of `State`, we set the initial value of this to `[]`, an empty list.
2. In `btn.click()`, we have a reference to `used_letters_var` in both the inputs and outputs.
3. In `guess_letter`, we pass the value of this `Variable` to `used_letters`, and then return an updated value of this `Variable` in the return statement.
3. In `guess_letter`, we pass the value of this `State` to `used_letters`, and then return an updated value of this `State` in the return statement.
With more complex apps, you will likely have many Variables storing session state in a single Blocks app.
With more complex apps, you will likely have many State variables storing session state in a single Blocks app.

View File

@ -232,7 +232,7 @@ def test_slider_random_value_config():
def test_io_components_attach_load_events_when_value_is_fn(io_components):
io_components = [comp for comp in io_components if not (comp == gr.State)]
interface = gr.Interface(
lambda *args: None,
inputs=[comp(value=lambda: None) for comp in io_components],
@ -247,7 +247,7 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components):
def test_blocks_do_not_filter_none_values_from_updates(io_components):
io_components = [c() for c in io_components if c not in [gr.Variable, gr.Button]]
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
with gr.Blocks() as demo:
for component in io_components:
component.render()

View File

@ -1736,5 +1736,48 @@ def test_video_postprocess_converts_to_playable_format(test_file_dir):
assert processing_utils.video_is_playable(str(full_path_to_output))
class TestState:
def test_as_component(self):
state = gr.State(value=5)
assert state.preprocess(10) == 10
assert state.preprocess("abc") == "abc"
assert state.stateful
@pytest.mark.asyncio
async def test_in_interface(self):
def test(x, y=" def"):
return (x + y, x + y)
io = gr.Interface(test, ["text", "state"], ["text", "state"])
result = await io.call_function(0, ["abc"])
assert result[0][0] == "abc def"
result = await io.call_function(0, ["abc", result[0][0]])
assert result[0][0] == "abcabc def"
@pytest.mark.asyncio
async def test_in_blocks(self):
with gr.Blocks() as demo:
score = gr.State()
btn = gr.Button()
btn.click(lambda x: x + 1, score, score)
result = await demo.call_function(0, [0])
assert result[0] == 1
result = await demo.call_function(0, [result[0]])
assert result[0] == 2
@pytest.mark.asyncio
async def test_variable_for_backwards_compatibility(self):
with gr.Blocks() as demo:
score = gr.Variable()
btn = gr.Button()
btn.click(lambda x: x + 1, score, score)
result = await demo.call_function(0, [0])
assert result[0] == 1
result = await demo.call_function(0, [result[0]])
assert result[0] == 2
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,2 @@
export { default as Component } from "./State.svelte";
export const modes = ["static"];

View File

@ -1,2 +0,0 @@
export { default as Component } from "./Variable.svelte";
export const modes = ["static"];

View File

@ -28,11 +28,11 @@ export const component_map = {
radio: () => import("./Radio"),
row: () => import("./Row"),
slider: () => import("./Slider"),
state: () => import("./State"),
statustracker: () => import("./StatusTracker"),
tabs: () => import("./Tabs"),
tabitem: () => import("./TabItem"),
textbox: () => import("./Textbox"),
timeseries: () => import("./TimeSeries"),
variable: () => import("./Variable"),
video: () => import("./Video")
};