mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Quick fix for partial functions (#2235)
* fix for partial * added test * fixed tests by adding self
This commit is contained in:
parent
fb434fc73d
commit
2942e24aa3
@ -248,7 +248,9 @@ class BlockFunction:
|
||||
def __str__(self):
|
||||
return str(
|
||||
{
|
||||
"fn": self.fn.__name__ if self.fn is not None else None,
|
||||
"fn": getattr(self.fn, "__name__", "fn")
|
||||
if self.fn is not None
|
||||
else None,
|
||||
"preprocess": self.preprocess,
|
||||
"postprocess": self.postprocess,
|
||||
}
|
||||
|
@ -288,7 +288,7 @@ class Interface(Blocks):
|
||||
self.api_mode = _api_mode
|
||||
self.fn = fn
|
||||
self.fn_durations = [0, 0]
|
||||
self.__name__ = fn.__name__
|
||||
self.__name__ = getattr(fn, "__name__", "fn")
|
||||
self.live = live
|
||||
self.title = title
|
||||
|
||||
|
@ -7,6 +7,8 @@ import time
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from string import capwords
|
||||
from unittest.mock import patch
|
||||
|
||||
import mlflow
|
||||
@ -124,6 +126,19 @@ class TestBlocks(unittest.TestCase):
|
||||
config2 = demo2.get_config_file()
|
||||
self.assertTrue(assert_configs_are_equivalent_besides_ids(config1, config2))
|
||||
|
||||
def test_partial_fn_in_config(self):
|
||||
def greet(name, formatter):
|
||||
return formatter("Hello " + name + "!")
|
||||
|
||||
greet_upper_case = partial(greet, formatter=capwords)
|
||||
with gr.Blocks() as demo:
|
||||
t = gr.Textbox()
|
||||
o = gr.Textbox()
|
||||
t.change(greet_upper_case, t, o)
|
||||
|
||||
assert len(demo.fns) == 1
|
||||
assert "fn" in str(demo.fns[0])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function(self):
|
||||
async def wait():
|
||||
@ -206,77 +221,84 @@ class TestBlocks(unittest.TestCase):
|
||||
mock_post.assert_called_once()
|
||||
|
||||
|
||||
def test_slider_random_value_config():
|
||||
with gr.Blocks() as demo:
|
||||
gr.Slider(
|
||||
value=11.2, minimum=-10.2, maximum=15, label="Non-random Slider (Static)"
|
||||
)
|
||||
gr.Slider(
|
||||
randomize=True, minimum=100, maximum=200, label="Random Slider (Input 1)"
|
||||
)
|
||||
gr.Slider(
|
||||
randomize=True, minimum=10, maximum=23.2, label="Random Slider (Input 2)"
|
||||
)
|
||||
for component in demo.blocks.values():
|
||||
if isinstance(component, gr.components.IOComponent):
|
||||
if "Non-random" in component.label:
|
||||
assert not component.attach_load_event
|
||||
else:
|
||||
assert component.attach_load_event
|
||||
dependencies_on_load = [
|
||||
dep["trigger"] == "load" for dep in demo.config["dependencies"]
|
||||
]
|
||||
assert all(dependencies_on_load)
|
||||
assert len(dependencies_on_load) == 2
|
||||
assert not any([dep["queue"] for dep in demo.config["dependencies"]])
|
||||
class TestComponentsInBlocks:
|
||||
def test_slider_random_value_config(self):
|
||||
with gr.Blocks() as demo:
|
||||
gr.Slider(
|
||||
value=11.2,
|
||||
minimum=-10.2,
|
||||
maximum=15,
|
||||
label="Non-random Slider (Static)",
|
||||
)
|
||||
gr.Slider(
|
||||
randomize=True,
|
||||
minimum=100,
|
||||
maximum=200,
|
||||
label="Random Slider (Input 1)",
|
||||
)
|
||||
gr.Slider(
|
||||
randomize=True,
|
||||
minimum=10,
|
||||
maximum=23.2,
|
||||
label="Random Slider (Input 2)",
|
||||
)
|
||||
for component in demo.blocks.values():
|
||||
if isinstance(component, gr.components.IOComponent):
|
||||
if "Non-random" in component.label:
|
||||
assert not component.attach_load_event
|
||||
else:
|
||||
assert component.attach_load_event
|
||||
dependencies_on_load = [
|
||||
dep["trigger"] == "load" for dep in demo.config["dependencies"]
|
||||
]
|
||||
assert all(dependencies_on_load)
|
||||
assert len(dependencies_on_load) == 2
|
||||
assert not any([dep["queue"] for dep in demo.config["dependencies"]])
|
||||
|
||||
|
||||
def test_io_components_attach_load_events_when_value_is_fn(io_components):
|
||||
io_components = [comp for comp in io_components if not (comp == gr.State)]
|
||||
interface = gr.Interface(
|
||||
lambda *args: None,
|
||||
inputs=[comp(value=lambda: None) for comp in io_components],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
dependencies_on_load = [
|
||||
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
|
||||
]
|
||||
assert len(dependencies_on_load) == len(io_components)
|
||||
|
||||
|
||||
def test_blocks_do_not_filter_none_values_from_updates(io_components):
|
||||
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
component.render()
|
||||
btn = gr.Button(value="Reset")
|
||||
btn.click(
|
||||
lambda: [gr.update(value=None) for _ in io_components],
|
||||
inputs=[],
|
||||
outputs=io_components,
|
||||
def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
|
||||
io_components = [comp for comp in io_components if not (comp == gr.State)]
|
||||
interface = gr.Interface(
|
||||
lambda *args: None,
|
||||
inputs=[comp(value=lambda: None) for comp in io_components],
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
output = demo.postprocess_data(
|
||||
0, [gr.update(value=None) for _ in io_components], state=None
|
||||
)
|
||||
assert all(
|
||||
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
|
||||
)
|
||||
dependencies_on_load = [
|
||||
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
|
||||
]
|
||||
assert len(dependencies_on_load) == len(io_components)
|
||||
|
||||
def test_blocks_do_not_filter_none_values_from_updates(self, io_components):
|
||||
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
component.render()
|
||||
btn = gr.Button(value="Reset")
|
||||
btn.click(
|
||||
lambda: [gr.update(value=None) for _ in io_components],
|
||||
inputs=[],
|
||||
outputs=io_components,
|
||||
)
|
||||
|
||||
def test_blocks_does_not_replace_keyword_literal():
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.Textbox()
|
||||
btn = gr.Button(value="Reset")
|
||||
btn.click(
|
||||
lambda: gr.update(value="NO_VALUE"),
|
||||
inputs=[],
|
||||
outputs=text,
|
||||
output = demo.postprocess_data(
|
||||
0, [gr.update(value=None) for _ in io_components], state=None
|
||||
)
|
||||
assert all(
|
||||
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
|
||||
)
|
||||
|
||||
output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None)
|
||||
assert output[0]["value"] == "NO_VALUE"
|
||||
def test_blocks_does_not_replace_keyword_literal(self):
|
||||
with gr.Blocks() as demo:
|
||||
text = gr.Textbox()
|
||||
btn = gr.Button(value="Reset")
|
||||
btn.click(
|
||||
lambda: gr.update(value="NO_VALUE"),
|
||||
inputs=[],
|
||||
outputs=text,
|
||||
)
|
||||
|
||||
output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None)
|
||||
assert output[0]["value"] == "NO_VALUE"
|
||||
|
||||
|
||||
class TestCallFunction:
|
||||
|
@ -3,6 +3,8 @@ import sys
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from string import capwords
|
||||
|
||||
import mlflow
|
||||
import pytest
|
||||
@ -49,6 +51,14 @@ class TestInterface(unittest.TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
Interface(lambda x: x, examples=1234)
|
||||
|
||||
def test_partial_functions(self):
|
||||
def greet(name, formatter):
|
||||
return formatter("Hello " + name + "!")
|
||||
|
||||
greet_upper_case = partial(greet, formatter=capwords)
|
||||
demo = Interface(fn=greet_upper_case, inputs="text", outputs="text")
|
||||
assert demo("abubakar") == "Hello Abubakar!"
|
||||
|
||||
def test_examples_valid_path(self):
|
||||
path = os.path.join(
|
||||
os.path.dirname(__file__), "../gradio/test_data/flagged_with_log"
|
||||
|
Loading…
x
Reference in New Issue
Block a user