mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
Improvements to State
(#2100)
* state * state fix * variable -> state * fix * added state tests * formatting * fix test * formatting * fix test * added tests for bakcward compatibility * formatting * config fix * additional doc * doc fix * formatting
This commit is contained in:
parent
1ad587834a
commit
4d58ae79b3
@ -21,7 +21,7 @@ with demo:
|
||||
flip_btn = gr.Button("Flip Card").style(full_width=True)
|
||||
with gr.Column(visible=False) as answer_col:
|
||||
back = gr.Textbox(label="Answer")
|
||||
selected_card = gr.Variable()
|
||||
selected_card = gr.State()
|
||||
with gr.Row():
|
||||
correct_btn = gr.Button(
|
||||
"Correct",
|
||||
@ -29,7 +29,7 @@ with demo:
|
||||
incorrect_btn = gr.Button("Incorrect").style(full_width=True)
|
||||
|
||||
with gr.TabItem("Results"):
|
||||
results = gr.Variable(value={})
|
||||
results = gr.State(value={})
|
||||
correct_field = gr.Markdown("# Correct: 0")
|
||||
incorrect_field = gr.Markdown("# Incorrect: 0")
|
||||
gr.Markdown("Card Statistics: ")
|
||||
|
@ -5,11 +5,11 @@ demo = gr.Blocks(css="#btn {color: red}")
|
||||
with demo:
|
||||
default_json = {"a": "a"}
|
||||
|
||||
num = gr.Variable(value=0)
|
||||
num = gr.State(value=0)
|
||||
squared = gr.Number(value=0)
|
||||
btn = gr.Button("Next Square", elem_id="btn").style(rounded=False)
|
||||
|
||||
stats = gr.Variable(value=default_json)
|
||||
stats = gr.State(value=default_json)
|
||||
table = gr.JSON()
|
||||
|
||||
def increase(var, stats_history):
|
||||
|
@ -42,8 +42,8 @@ with gr.Blocks() as Dataframe_demo:
|
||||
with gr.Blocks() as Timeseries_demo:
|
||||
gr.Timeseries()
|
||||
|
||||
with gr.Blocks() as Variable_demo:
|
||||
gr.Variable()
|
||||
with gr.Blocks() as State_demo:
|
||||
gr.State()
|
||||
|
||||
with gr.Blocks() as Button_demo:
|
||||
gr.Button()
|
||||
|
@ -4,7 +4,7 @@ import random
|
||||
secret_word = "gradio"
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
used_letters_var = gr.Variable([])
|
||||
used_letters_var = gr.State([])
|
||||
with gr.Row() as row:
|
||||
with gr.Column():
|
||||
input_letter = gr.Textbox(label="Enter letter")
|
||||
|
@ -17,7 +17,7 @@ from constants import (
|
||||
|
||||
|
||||
demo = gr.Interface(
|
||||
lambda x: x,
|
||||
lambda *args: args[0],
|
||||
inputs=[
|
||||
gr.Textbox(value=lambda: datetime.now(), label="Current Time"),
|
||||
gr.Number(value=lambda: random.random(), label="Ranom Percentage"),
|
||||
@ -60,7 +60,7 @@ demo = gr.Interface(
|
||||
)
|
||||
),
|
||||
gr.Timeseries(value=lambda: os.path.join(file_dir, "time.csv")),
|
||||
gr.Variable(value=lambda: random.choice(string.ascii_lowercase)),
|
||||
gr.State(value=lambda: random.choice(string.ascii_lowercase)),
|
||||
gr.Button(value=lambda: random.choice(["Run", "Go", "predict"])),
|
||||
gr.ColorPicker(value=lambda: random.choice(["#000000", "#ff0000", "#0000FF"])),
|
||||
gr.Label(value=lambda: random.choice(["Pedestrian", "Car", "Cyclist"])),
|
||||
@ -91,7 +91,9 @@ demo = gr.Interface(
|
||||
gr.Plot(value=random_plot),
|
||||
gr.Markdown(value=lambda: f"### {random.choice(['Hello', 'Hi', 'Goodbye!'])}"),
|
||||
],
|
||||
outputs=None,
|
||||
outputs=[
|
||||
gr.State(value=lambda: random.choice(string.ascii_lowercase))
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
with gr.Blocks() as demo:
|
||||
inp = gr.Audio(source="microphone")
|
||||
out = gr.Audio()
|
||||
stream = gr.Variable()
|
||||
stream = gr.State()
|
||||
|
||||
def add_to_stream(audio, instream):
|
||||
if audio is None:
|
||||
|
@ -35,6 +35,7 @@ from gradio.components import (
|
||||
Plot,
|
||||
Radio,
|
||||
Slider,
|
||||
State,
|
||||
StatusTracker,
|
||||
Textbox,
|
||||
TimeSeries,
|
||||
|
@ -2541,10 +2541,10 @@ class Timeseries(Changeable, IOComponent, JSONSerializable):
|
||||
|
||||
|
||||
@document()
|
||||
class Variable(IOComponent, SimpleSerializable):
|
||||
class State(IOComponent, SimpleSerializable):
|
||||
"""
|
||||
Special hidden component that stores session state across runs of the demo by the
|
||||
same user. The value of the Variable is cleared when the user refreshes the page.
|
||||
same user. The value of the State variable is cleared when the user refreshes the page.
|
||||
|
||||
Preprocessing: No preprocessing is performed
|
||||
Postprocessing: No postprocessing is performed
|
||||
@ -2570,6 +2570,16 @@ class Variable(IOComponent, SimpleSerializable):
|
||||
return self
|
||||
|
||||
|
||||
class Variable(State):
|
||||
"""Variable was renamed to State. This class is kept for backwards compatibility."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_block_name(self):
|
||||
return "state"
|
||||
|
||||
|
||||
@document("click", "style")
|
||||
class Button(Clickable, IOComponent, SimpleSerializable):
|
||||
"""
|
||||
|
@ -427,7 +427,7 @@ class Timeseries(components.Timeseries):
|
||||
super().__init__(x=x, y=y, label=label, optional=optional)
|
||||
|
||||
|
||||
class State(components.Variable):
|
||||
class State(components.State):
|
||||
"""
|
||||
Special hidden component that stores state across runs of the interface.
|
||||
Input type: Any
|
||||
@ -445,7 +445,7 @@ class State(components.Variable):
|
||||
optional (bool): this parameter is ignored.
|
||||
"""
|
||||
warnings.warn(
|
||||
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.Variable from gradio.components",
|
||||
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
|
||||
)
|
||||
super().__init__(value=default, label=label)
|
||||
|
||||
|
@ -27,8 +27,8 @@ from gradio.components import (
|
||||
Interpretation,
|
||||
IOComponent,
|
||||
Markdown,
|
||||
State,
|
||||
StatusTracker,
|
||||
Variable,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
@ -213,17 +213,30 @@ class Interface(Blocks):
|
||||
else:
|
||||
self.cache_examples = cache_examples or False
|
||||
|
||||
if "state" in inputs or "state" in outputs:
|
||||
state_input_count = len([i for i in inputs if i == "state"])
|
||||
state_output_count = len([o for o in outputs if o == "state"])
|
||||
if state_input_count != 1 or state_output_count != 1:
|
||||
raise ValueError(
|
||||
"If using 'state', there must be exactly one state input and one state output."
|
||||
)
|
||||
default = utils.get_default_args(fn)[inputs.index("state")]
|
||||
state_variable = Variable(value=default)
|
||||
inputs[inputs.index("state")] = state_variable
|
||||
outputs[outputs.index("state")] = state_variable
|
||||
state_input_indexes = [
|
||||
idx for idx, i in enumerate(inputs) if i == "state" or isinstance(i, State)
|
||||
]
|
||||
state_output_indexes = [
|
||||
idx for idx, o in enumerate(outputs) if o == "state" or isinstance(o, State)
|
||||
]
|
||||
|
||||
if len(state_input_indexes) == 0 and len(state_output_indexes) == 0:
|
||||
pass
|
||||
elif len(state_input_indexes) != 1 or len(state_output_indexes) != 1:
|
||||
raise ValueError(
|
||||
"If using 'state', there must be exactly one state input and one state output."
|
||||
)
|
||||
else:
|
||||
state_input_index = state_input_indexes[0]
|
||||
state_output_index = state_output_indexes[0]
|
||||
if inputs[state_input_index] == "state":
|
||||
default = utils.get_default_args(fn)[state_input_index]
|
||||
state_variable = State(value=default)
|
||||
else:
|
||||
state_variable = inputs[state_input_index]
|
||||
|
||||
inputs[state_input_index] = state_variable
|
||||
outputs[state_output_index] = state_variable
|
||||
|
||||
if cache_examples:
|
||||
warnings.warn(
|
||||
@ -240,9 +253,7 @@ class Interface(Blocks):
|
||||
]
|
||||
|
||||
for component in self.input_components + self.output_components:
|
||||
if not (
|
||||
isinstance(component, IOComponent) or isinstance(component, Variable)
|
||||
):
|
||||
if not (isinstance(component, IOComponent)):
|
||||
raise ValueError(
|
||||
f"{component} is not a valid input/output component for Interface."
|
||||
)
|
||||
@ -607,10 +618,10 @@ class Interface(Blocks):
|
||||
|
||||
if self.examples:
|
||||
non_state_inputs = [
|
||||
c for c in self.input_components if not isinstance(c, Variable)
|
||||
c for c in self.input_components if not isinstance(c, State)
|
||||
]
|
||||
non_state_outputs = [
|
||||
c for c in self.output_components if not isinstance(c, Variable)
|
||||
c for c in self.output_components if not isinstance(c, State)
|
||||
]
|
||||
self.examples_handler = Examples(
|
||||
examples=examples,
|
||||
|
@ -158,7 +158,7 @@ class Timeseries(components.Timeseries):
|
||||
super().__init__(x=x, y=y, label=label)
|
||||
|
||||
|
||||
class State(components.Variable):
|
||||
class State(components.State):
|
||||
"""
|
||||
Special hidden component that stores state across runs of the interface.
|
||||
Output type: Any
|
||||
@ -170,7 +170,7 @@ class State(components.Variable):
|
||||
label (str): component name in interface (not used).
|
||||
"""
|
||||
warnings.warn(
|
||||
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
|
||||
"Usage of gradio.outputs is deprecated, and will not be supported in the future, please import this component as gr.State() from gradio.components",
|
||||
)
|
||||
super().__init__(label=label)
|
||||
|
||||
|
@ -198,7 +198,7 @@ def launch_counter() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
def get_default_args(func: Callable) -> List[Any]:
|
||||
signature = inspect.signature(func)
|
||||
return [
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
|
@ -23,4 +23,4 @@ $demo_chatbot_demo
|
||||
|
||||
Notice how the state persists across submits within each page, but if you load this demo in another tab (or refresh the page), the demos will not share chat history.
|
||||
|
||||
The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead.
|
||||
The default value of `state` is None. If you pass a default value to the state parameter of the function, it is used as the default value of the state instead. The `Interface` class only supports a single input and outputs state variable, though it can be a list with multiple elements. For more complex use cases, you can use Blocks, [which supports multiple `State` variables](/state_in_blocks/).
|
@ -8,8 +8,8 @@ Global state in Blocks works the same as in Interface. Any variable created outs
|
||||
|
||||
Gradio supports session **state**, where data persists across multiple submits within a page session, in Blocks apps as well. To reiterate, session data is *not* shared between different users of your model. To store data in a session state, you need to do three things:
|
||||
|
||||
1. Create a `gr.Variable()` object. If there is a default value to this stateful object, pass that into the constructor.
|
||||
2. In the event listener, put the `Variable` object as an input and output.
|
||||
1. Create a `gr.State()` object. If there is a default value to this stateful object, pass that into the constructor.
|
||||
2. In the event listener, put the `State` object as an input and output.
|
||||
3. In the event listener function, add the variable to the input parameters and the return value.
|
||||
|
||||
Let's take a look at a game of hangman.
|
||||
@ -19,11 +19,11 @@ $demo_hangman
|
||||
|
||||
Let's see how we do each of the 3 steps listed above in this game:
|
||||
|
||||
1. We store the used letters in `used_letters_var`. In the constructor of `Variable`, we set the initial value of this to `[]`, an empty list.
|
||||
1. We store the used letters in `used_letters_var`. In the constructor of `State`, we set the initial value of this to `[]`, an empty list.
|
||||
2. In `btn.click()`, we have a reference to `used_letters_var` in both the inputs and outputs.
|
||||
3. In `guess_letter`, we pass the value of this `Variable` to `used_letters`, and then return an updated value of this `Variable` in the return statement.
|
||||
3. In `guess_letter`, we pass the value of this `State` to `used_letters`, and then return an updated value of this `State` in the return statement.
|
||||
|
||||
With more complex apps, you will likely have many Variables storing session state in a single Blocks app.
|
||||
With more complex apps, you will likely have many State variables storing session state in a single Blocks app.
|
||||
|
||||
|
||||
|
||||
|
@ -232,7 +232,7 @@ def test_slider_random_value_config():
|
||||
|
||||
|
||||
def test_io_components_attach_load_events_when_value_is_fn(io_components):
|
||||
|
||||
io_components = [comp for comp in io_components if not (comp == gr.State)]
|
||||
interface = gr.Interface(
|
||||
lambda *args: None,
|
||||
inputs=[comp(value=lambda: None) for comp in io_components],
|
||||
@ -247,7 +247,7 @@ def test_io_components_attach_load_events_when_value_is_fn(io_components):
|
||||
|
||||
def test_blocks_do_not_filter_none_values_from_updates(io_components):
|
||||
|
||||
io_components = [c() for c in io_components if c not in [gr.Variable, gr.Button]]
|
||||
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
component.render()
|
||||
|
@ -1736,5 +1736,48 @@ def test_video_postprocess_converts_to_playable_format(test_file_dir):
|
||||
assert processing_utils.video_is_playable(str(full_path_to_output))
|
||||
|
||||
|
||||
class TestState:
|
||||
def test_as_component(self):
|
||||
state = gr.State(value=5)
|
||||
assert state.preprocess(10) == 10
|
||||
assert state.preprocess("abc") == "abc"
|
||||
assert state.stateful
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_interface(self):
|
||||
def test(x, y=" def"):
|
||||
return (x + y, x + y)
|
||||
|
||||
io = gr.Interface(test, ["text", "state"], ["text", "state"])
|
||||
result = await io.call_function(0, ["abc"])
|
||||
assert result[0][0] == "abc def"
|
||||
result = await io.call_function(0, ["abc", result[0][0]])
|
||||
assert result[0][0] == "abcabc def"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_blocks(self):
|
||||
with gr.Blocks() as demo:
|
||||
score = gr.State()
|
||||
btn = gr.Button()
|
||||
btn.click(lambda x: x + 1, score, score)
|
||||
|
||||
result = await demo.call_function(0, [0])
|
||||
assert result[0] == 1
|
||||
result = await demo.call_function(0, [result[0]])
|
||||
assert result[0] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_variable_for_backwards_compatibility(self):
|
||||
with gr.Blocks() as demo:
|
||||
score = gr.Variable()
|
||||
btn = gr.Button()
|
||||
btn.click(lambda x: x + 1, score, score)
|
||||
|
||||
result = await demo.call_function(0, [0])
|
||||
assert result[0] == 1
|
||||
result = await demo.call_function(0, [result[0]])
|
||||
assert result[0] == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
2
ui/packages/app/src/components/State/index.ts
Normal file
2
ui/packages/app/src/components/State/index.ts
Normal file
@ -0,0 +1,2 @@
|
||||
export { default as Component } from "./State.svelte";
|
||||
export const modes = ["static"];
|
@ -1,2 +0,0 @@
|
||||
export { default as Component } from "./Variable.svelte";
|
||||
export const modes = ["static"];
|
@ -28,11 +28,11 @@ export const component_map = {
|
||||
radio: () => import("./Radio"),
|
||||
row: () => import("./Row"),
|
||||
slider: () => import("./Slider"),
|
||||
state: () => import("./State"),
|
||||
statustracker: () => import("./StatusTracker"),
|
||||
tabs: () => import("./Tabs"),
|
||||
tabitem: () => import("./TabItem"),
|
||||
textbox: () => import("./Textbox"),
|
||||
timeseries: () => import("./TimeSeries"),
|
||||
variable: () => import("./Variable"),
|
||||
video: () => import("./Video")
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user