State and variables (#977)

Restore state and add variables component

Co-authored-by: Ali Abid <aliabid94@gmail.com>
This commit is contained in:
aliabid94 2022-04-12 18:41:13 -07:00 committed by GitHub
parent ceea8ce3ca
commit 1d3cb510bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 274 additions and 105 deletions

View File

@ -0,0 +1,91 @@
import random
import gradio as gr
demo = gr.Blocks()
with demo:
gr.Markdown(
"Load the flashcards in the table below, then use the Practice tab to practice."
)
with gr.Tabs():
with gr.TabItem("Word Bank"):
flashcards_table = gr.Dataframe(headers=["front", "back"], type="array")
with gr.TabItem("Practice"):
with gr.Row():
front = gr.Textbox()
answer_row = gr.Row(visible=False)
with answer_row:
back = gr.Textbox()
with gr.Row():
new_btn = gr.Button("New Card")
flip_btn = gr.Button("Flip Card")
selected_card = gr.Variable()
feedback_row = gr.Row(visible=False)
with feedback_row:
correct_btn = gr.Button(
"Correct",
css={"background-color": "lightgreen", "color": "green"},
)
incorrect_btn = gr.Button(
"Incorrect", css={"background-color": "pink", "color": "red"}
)
with gr.TabItem("Results"):
results = gr.Variable(default_value={})
correct_field = gr.Markdown("# Correct: 0")
incorrect_field = gr.Markdown("# Incorrect: 0")
gr.Markdown("Card Statistics: ")
results_table = gr.Dataframe(headers=["Card", "Correct", "Incorrect"])
def load_new_card(flashcards):
card = random.choice(flashcards)
return card, card[0], False, False
new_btn.click(
load_new_card,
[flashcards_table],
[selected_card, front, answer_row, feedback_row],
)
def flip_card(card):
return card[1], True, True
flip_btn.click(flip_card, [selected_card], [back, answer_row, feedback_row])
def mark_correct(card, results):
if card[0] not in results:
results[card[0]] = [0, 0]
results[card[0]][0] += 1
correct_count = sum(result[0] for result in results.values())
return (
results,
f"# Correct: {correct_count}",
[[front, scores[0], scores[1]] for front, scores in results.items()],
)
def mark_incorrect(card, results):
if card[0] not in results:
results[card[0]] = [0, 0]
results[card[0]][1] += 1
incorrect_count = sum(result[1] for result in results.values())
return (
results,
f"# Inorrect: {incorrect_count}",
[[front, scores[0], scores[1]] for front, scores in results.items()],
)
correct_btn.click(
mark_correct,
[selected_card, results],
[results, correct_field, results_table],
)
incorrect_btn.click(
mark_incorrect,
[selected_card, results],
[results, incorrect_field, results_table],
)
if __name__ == "__main__":
demo.launch()

View File

@ -1,65 +1,76 @@
# A Blocks implementation of https://erlj.notion.site/Neural-Instrument-Cloning-from-very-few-samples-2cf41d8b630842ee8c7eb55036a1bfd6
# Needs to be run from the demo\blocks_neural_instrument_coding folder
import gradio as gr
from gradio.components import Markdown as m
import datetime
import random
import gradio as gr
from gradio.components import Markdown as m
def get_time():
now = datetime.datetime.now()
return now.strftime("%m/%d/%Y, %H:%M:%S")
def generate_recording():
return random.choice(["new-sax-1.mp3", "new-sax-1.wav"])
def reconstruct(audio):
return random.choice(["new-sax-1.mp3", "new-sax-1.wav"])
io1 = gr.Interface(
lambda x,y,z:"sax.wav",
[gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)")
],
gr.Audio()
lambda x, y, z: "sax.wav",
[
gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)"),
],
gr.Audio(),
)
io2 = gr.Interface(
lambda x,y,z:"flute.wav",
[gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)")
],
gr.Audio()
lambda x, y, z: "flute.wav",
[
gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)"),
],
gr.Audio(),
)
io3 = gr.Interface(
lambda x,y,z:"trombone.wav",
[gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)")
],
gr.Audio()
lambda x, y, z: "trombone.wav",
[
gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)"),
],
gr.Audio(),
)
io4 = gr.Interface(
lambda x,y,z:"sax2.wav",
[gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)")
],
gr.Audio()
lambda x, y, z: "sax2.wav",
[
gr.Slider(label="pitch"),
gr.Slider(label="loudness"),
gr.Audio(label="base audio file (optional)"),
],
gr.Audio(),
)
demo = gr.Blocks()
with demo.clear():
m("""
m(
"""
## Neural Instrument Cloning from Very Few Samples
<center><img src="https://media.istockphoto.com/photos/brass-trombone-picture-id490455809?k=20&m=490455809&s=612x612&w=0&h=l9KJvH_25z0QTLggHrcH_MsR4gPLH7uXwDPUAZ_C5zk=" width="400px"></center>"""
)
m("""
m(
"""
This Blocks implementation is an adaptation [a report written](https://erlj.notion.site/Neural-Instrument-Cloning-from-very-few-samples-2cf41d8b630842ee8c7eb55036a1bfd6) by Nicolas Jonason and Bob L.T. Sturm.
I've implemented it in Blocks to show off some cool features, such as embedding live ML demos. More on that ahead...
@ -69,54 +80,64 @@ with demo.clear():
### Audio Examples
Here are some **real** 16 second saxophone recordings:
""")
"""
)
gr.Audio("sax.wav", label="Here is a real 16 second saxophone recording:")
gr.Audio("sax.wav")
m(
"""\n
Here is a **generated** saxophone recordings:"""
)
a = gr.Audio("new-sax.wav")
gr.Button("Generate a new saxophone recording")
m("""
m(
"""
### Inputs to the model
The inputs to the model are:
* pitch
* loudness
* base audio file
""")
m("""
"""
)
m(
"""
Try the model live!
""")
gr.TabbedInterface([io1, io2, io3, io4], ["Saxophone", "Flute", "Trombone", "Another Saxophone"])
m("""
"""
)
gr.TabbedInterface(
[io1, io2, io3, io4], ["Saxophone", "Flute", "Trombone", "Another Saxophone"]
)
m(
"""
### Using the model for cloning
You can also use this model a different way, to simply clone the audio file and reconstruct it
using machine learning. Here, we'll show a demo of that below:
""")
"""
)
a2 = gr.Audio()
a2.change(reconstruct, a2, a2)
m("""
m(
"""
Thanks for reading this! As you may have realized, all of the "models" in this demo are fake. They are just designed to show you what is possible using Blocks 🤗.
For details of the model, read the [original report here](https://erlj.notion.site/Neural-Instrument-Cloning-from-very-few-samples-2cf41d8b630842ee8c7eb55036a1bfd6).
*Details for nerds*: this report was "launched" on:
""")
"""
)
t = gr.Textbox(label="timestamp")
demo.load(get_time, [], t)
if __name__ == "__main__":
demo.launch()
demo.launch()

View File

@ -0,0 +1,16 @@
import gradio as gr
test = gr.Blocks()
with test:
num = gr.Variable(default_value=0)
squared = gr.Number(default_value=0)
btn = gr.Button("Next Square")
def increase(var):
var += 1
return var, var**2
btn.click(increase, [num], [num, squared])
test.launch()

View File

@ -1,4 +1,4 @@
Metadata-Version: 1.0
Metadata-Version: 2.1
Name: gradio
Version: 2.7.0b70
Summary: Python library for easily interacting with trained machine learning models
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid, Ali Abid, Ali Abdalla, Dawood Khan, Ahsen Khaliq, Pete Allen, Ömer Faruk Özdemir
Author-email: team@gradio.app
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE
UNKNOWN

View File

@ -1,5 +1,5 @@
aiohttp
analytics-python
aiohttp
fastapi
ffmpy
markdown-it-py[linkify,plugins]
@ -10,7 +10,7 @@ pandas
paramiko
pillow
pycryptodome
pydub
python-multipart
pydub
requests
uvicorn

View File

@ -25,9 +25,9 @@ from gradio.components import (
Number,
Radio,
Slider,
State,
Textbox,
Timeseries,
Variable,
Video,
)
from gradio.flagging import (

View File

@ -208,30 +208,54 @@ class Blocks(BlockContext):
self._id = Context.id
Context.id += 1
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
def process_api(
self,
data: Dict[str, Any],
username: str = None,
state: Optional[Dict[int, any]] = None,
) -> Dict[str, Any]:
"""
Processes API calls from the frontend.
Parameters:
data: data recieved from the frontend
username: name of user if authentication is set up
state: data stored from stateful components for session
Returns: None
"""
raw_input = data["data"]
fn_index = data["fn_index"]
fn, preprocess, postprocess = self.fns[fn_index]
dependency = self.dependencies[fn_index]
if preprocess:
processed_input = [
self.blocks[input_id].preprocess(raw_input[i])
for i, input_id in enumerate(dependency["inputs"])
]
processed_input = []
for i, input_id in enumerate(dependency["inputs"]):
block = self.blocks[input_id]
if getattr(block, "stateful", False):
processed_input.append(state.get(input_id))
else:
processed_input.append(block.preprocess(raw_input[i]))
predictions = fn(*processed_input)
else:
predictions = fn(*raw_input)
if len(dependency["outputs"]) == 1:
predictions = (predictions,)
if postprocess:
predictions = [
self.blocks[output_id].postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_id in enumerate(dependency["outputs"])
]
return {"data": predictions}
output = []
for i, output_id in enumerate(dependency["outputs"]):
block = self.blocks[output_id]
if getattr(block, "stateful", False):
state[output_id] = predictions[i]
output.append(None)
else:
output.append(
block.postprocess(predictions[i])
if predictions[i] is not None
else None
)
else:
output = predictions
return {"data": output}
def get_template_context(self):
return {"type": "column"}

View File

@ -2267,7 +2267,7 @@ class Timeseries(Component):
self.set_event_trigger("change", fn, inputs, outputs)
class State(Component):
class Variable(Component):
"""
Special hidden component that stores state across runs of the interface.
@ -2279,9 +2279,6 @@ class State(Component):
def __init__(
self,
default_value: Any = None,
*,
label: Optional[str] = None,
css: Optional[Dict] = None,
**kwargs,
):
"""
@ -2290,7 +2287,8 @@ class State(Component):
label (str): component name in interface (not used).
"""
self.default_value = default_value
super().__init__(label=label, css=css, **kwargs)
self.stateful = True
super().__init__(**kwargs)
def get_template_context(self):
return {"default_value": self.default_value, **super().get_template_context()}

View File

@ -19,9 +19,9 @@ from gradio.components import Image as C_Image
from gradio.components import Number as C_Number
from gradio.components import Radio as C_Radio
from gradio.components import Slider as C_Slider
from gradio.components import State as C_State
from gradio.components import Textbox as C_Textbox
from gradio.components import Timeseries as C_Timeseries
from gradio.components import Variable as C_Variable
from gradio.components import Video as C_Video
@ -465,7 +465,7 @@ class Timeseries(C_Timeseries):
super().__init__(x=x, y=y, label=label, optional=optional)
class State(C_State):
class State(C_Variable):
"""
Special hidden component that stores state across runs of the interface.
Input type: Any
@ -476,7 +476,6 @@ class State(C_State):
self,
label: str = None,
default: Any = None,
optional: bool = False,
):
"""
Parameters:
@ -485,7 +484,7 @@ class State(C_State):
optional (bool): this parameter is ignored.
"""
warnings.warn(
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import your components from gradio.components",
"Usage of gradio.inputs is deprecated, and will not be supported in the future, please import this component as gr.Variable from gradio.components",
DeprecationWarning,
)
super().__init__(default_value=default, label=label, optional=optional)
super().__init__(default_value=default, label=label)

View File

@ -27,6 +27,7 @@ from gradio.components import (
Dataset,
Interpretation,
Markdown,
Variable,
get_component_instance,
)
from gradio.external import load_from_pipeline, load_interface # type: ignore
@ -178,27 +179,23 @@ class Interface(Blocks):
if not isinstance(outputs, list):
outputs = [outputs]
if "state" in inputs or "state" in outputs:
state_input_count = len([i for i in inputs if i == "state"])
state_output_count = len([o for o in outputs if o == "state"])
if state_input_count != 1 or state_output_count != 1:
raise ValueError(
"If using 'state', there must be exactly one state input and one state output."
)
default = utils.get_default_args(fn[0])[inputs.index("state")]
state_variable = Variable(default_value=default)
inputs[inputs.index("state")] = state_variable
outputs[outputs.index("state")] = state_variable
self.input_components = [get_component_instance(i) for i in inputs]
self.output_components = [get_component_instance(o) for o in outputs]
if repeat_outputs_per_model:
self.output_components *= len(fn)
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
raise ValueError("Only one input component can be State.")
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
raise ValueError("Only one output component can be State.")
if sum(isinstance(i, i_State) for i in self.input_components) == 1:
if len(fn) > 1:
raise ValueError("State cannot be used with multiple functions.")
state_param_index = [
isinstance(i, i_State) for i in self.input_components
].index(True)
state: i_State = self.input_components[state_param_index]
if state.default_value is None:
default = utils.get_default_args(fn[0])[state_param_index]
state.default_value = default
if (
interpretation is None
or isinstance(interpretation, list)

View File

@ -20,9 +20,9 @@ from gradio.components import File as C_File
from gradio.components import HighlightedText as C_HighlightedText
from gradio.components import Image as C_Image
from gradio.components import Label as C_Label
from gradio.components import State as C_State
from gradio.components import Textbox as C_Textbox
from gradio.components import Timeseries as C_Timeseries
from gradio.components import Variable as C_State
from gradio.components import Video as C_Video

View File

@ -63,6 +63,7 @@ def create_app() -> FastAPI:
allow_methods=["*"],
allow_headers=["*"],
)
app.state_holder = {}
@app.get("/user")
@app.get("/user/")
@ -204,8 +205,23 @@ def create_app() -> FastAPI:
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
if "session_hash" in body:
if body["session_hash"] not in app.state_holder:
app.state_holder[body["session_hash"]] = {
_id: getattr(block, "default_value", None)
for _id, block in app.blocks.blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body["session_hash"]]
else:
session_state = {}
try:
output = await run_in_threadpool(app.blocks.process_api, body, username)
output = await run_in_threadpool(
app.blocks.process_api,
body,
username,
session_state,
)
except BaseException as error:
if app.blocks.show_error:
traceback.print_exc()
@ -308,15 +324,13 @@ def get_types(cls_set: List[Type], component: str):
def get_state():
raise DeprecationWarning(
"This function is deprecated. To create stateful demos, pass 'state'"
"as both an input and output component. Please see the getting started"
"guide for more information."
"This function is deprecated. To create stateful demos, use the Variable"
" component. Please see the getting started for more information."
)
def set_state(*args):
raise DeprecationWarning(
"This function is deprecated. To create stateful demos, pass 'state'"
"as both an input and output component. Please see the getting started"
"guide for more information."
"This function is deprecated. To create stateful demos, use the Variable"
" component. Please see the getting started for more information."
)

View File

@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
};
export const fn = async (
session_hash: string,
api_endpoint: string,
action: string,
data: Record<string, unknown>,
queue: boolean,
queue_callback: (pos: number | null, is_initial?: boolean) => void
) => {
data["session_hash"] = session_hash;
if (queue && ["predict", "interpret"].includes(action)) {
data["action"] = action;
const output = await postData(api_endpoint + "queue/push/", data);

View File

@ -0,0 +1,2 @@
export { default as Component } from "./Variable.svelte";
export const modes = ["static"];

View File

@ -26,5 +26,6 @@ export const component_map: Record<string, any> = {
tabitem: () => import("./TabItem"),
textbox: () => import("./Textbox"),
timeseries: () => import("./TimeSeries"),
video: () => import("./Video")
video: () => import("./Video"),
variable: () => import("./Variable")
};

View File

@ -99,7 +99,8 @@ window.launchGradio = (config: Config, element_query: string) => {
config.dark = true;
target.classList.add("dark");
}
config.fn = fn.bind(null, config.root + "api/");
let session_hash = Math.random().toString(36).substring(2);
config.fn = fn.bind(null, session_hash, config.root + "api/");
new Blocks({
target: target,
props: config

View File

@ -26,7 +26,7 @@ function mock_api(page: Page, body: Array<unknown>) {
}
test("renders the correct elements", async ({ page }) => {
await mock_demo(page, "xray_blocks");
await mock_demo(page, "blocks_xray");
await page.goto("http://localhost:3000");
const description = await page.locator(".output-markdown");
@ -40,7 +40,7 @@ test("renders the correct elements", async ({ page }) => {
});
test("can run an api request and display the data", async ({ page }) => {
await mock_demo(page, "xray_blocks");
await mock_demo(page, "blocks_xray");
await mock_api(page, [
[
{