Quick fix for partial functions (#2235)

* fix for partial

* added test

* fixed tests by adding self
This commit is contained in:
Abubakar Abid 2022-09-12 17:58:17 -07:00 committed by GitHub
parent fb434fc73d
commit 2942e24aa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 66 deletions

View File

@ -248,7 +248,9 @@ class BlockFunction:
def __str__(self):
return str(
{
"fn": self.fn.__name__ if self.fn is not None else None,
"fn": getattr(self.fn, "__name__", "fn")
if self.fn is not None
else None,
"preprocess": self.preprocess,
"postprocess": self.postprocess,
}

View File

@ -288,7 +288,7 @@ class Interface(Blocks):
self.api_mode = _api_mode
self.fn = fn
self.fn_durations = [0, 0]
self.__name__ = fn.__name__
self.__name__ = getattr(fn, "__name__", "fn")
self.live = live
self.title = title

View File

@ -7,6 +7,8 @@ import time
import unittest
import unittest.mock as mock
from contextlib import contextmanager
from functools import partial
from string import capwords
from unittest.mock import patch
import mlflow
@ -124,6 +126,19 @@ class TestBlocks(unittest.TestCase):
config2 = demo2.get_config_file()
self.assertTrue(assert_configs_are_equivalent_besides_ids(config1, config2))
def test_partial_fn_in_config(self):
def greet(name, formatter):
return formatter("Hello " + name + "!")
greet_upper_case = partial(greet, formatter=capwords)
with gr.Blocks() as demo:
t = gr.Textbox()
o = gr.Textbox()
t.change(greet_upper_case, t, o)
assert len(demo.fns) == 1
assert "fn" in str(demo.fns[0])
@pytest.mark.asyncio
async def test_async_function(self):
async def wait():
@ -206,77 +221,84 @@ class TestBlocks(unittest.TestCase):
mock_post.assert_called_once()
def test_slider_random_value_config():
with gr.Blocks() as demo:
gr.Slider(
value=11.2, minimum=-10.2, maximum=15, label="Non-random Slider (Static)"
)
gr.Slider(
randomize=True, minimum=100, maximum=200, label="Random Slider (Input 1)"
)
gr.Slider(
randomize=True, minimum=10, maximum=23.2, label="Random Slider (Input 2)"
)
for component in demo.blocks.values():
if isinstance(component, gr.components.IOComponent):
if "Non-random" in component.label:
assert not component.attach_load_event
else:
assert component.attach_load_event
dependencies_on_load = [
dep["trigger"] == "load" for dep in demo.config["dependencies"]
]
assert all(dependencies_on_load)
assert len(dependencies_on_load) == 2
assert not any([dep["queue"] for dep in demo.config["dependencies"]])
class TestComponentsInBlocks:
def test_slider_random_value_config(self):
with gr.Blocks() as demo:
gr.Slider(
value=11.2,
minimum=-10.2,
maximum=15,
label="Non-random Slider (Static)",
)
gr.Slider(
randomize=True,
minimum=100,
maximum=200,
label="Random Slider (Input 1)",
)
gr.Slider(
randomize=True,
minimum=10,
maximum=23.2,
label="Random Slider (Input 2)",
)
for component in demo.blocks.values():
if isinstance(component, gr.components.IOComponent):
if "Non-random" in component.label:
assert not component.attach_load_event
else:
assert component.attach_load_event
dependencies_on_load = [
dep["trigger"] == "load" for dep in demo.config["dependencies"]
]
assert all(dependencies_on_load)
assert len(dependencies_on_load) == 2
assert not any([dep["queue"] for dep in demo.config["dependencies"]])
def test_io_components_attach_load_events_when_value_is_fn(io_components):
io_components = [comp for comp in io_components if not (comp == gr.State)]
interface = gr.Interface(
lambda *args: None,
inputs=[comp(value=lambda: None) for comp in io_components],
outputs=None,
)
dependencies_on_load = [
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
]
assert len(dependencies_on_load) == len(io_components)
def test_blocks_do_not_filter_none_values_from_updates(io_components):
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
with gr.Blocks() as demo:
for component in io_components:
component.render()
btn = gr.Button(value="Reset")
btn.click(
lambda: [gr.update(value=None) for _ in io_components],
inputs=[],
outputs=io_components,
def test_io_components_attach_load_events_when_value_is_fn(self, io_components):
io_components = [comp for comp in io_components if not (comp == gr.State)]
interface = gr.Interface(
lambda *args: None,
inputs=[comp(value=lambda: None) for comp in io_components],
outputs=None,
)
output = demo.postprocess_data(
0, [gr.update(value=None) for _ in io_components], state=None
)
assert all(
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
)
dependencies_on_load = [
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
]
assert len(dependencies_on_load) == len(io_components)
def test_blocks_do_not_filter_none_values_from_updates(self, io_components):
io_components = [c() for c in io_components if c not in [gr.State, gr.Button]]
with gr.Blocks() as demo:
for component in io_components:
component.render()
btn = gr.Button(value="Reset")
btn.click(
lambda: [gr.update(value=None) for _ in io_components],
inputs=[],
outputs=io_components,
)
def test_blocks_does_not_replace_keyword_literal():
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button(value="Reset")
btn.click(
lambda: gr.update(value="NO_VALUE"),
inputs=[],
outputs=text,
output = demo.postprocess_data(
0, [gr.update(value=None) for _ in io_components], state=None
)
assert all(
[o["value"] == c.postprocess(None) for o, c in zip(output, io_components)]
)
output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None)
assert output[0]["value"] == "NO_VALUE"
def test_blocks_does_not_replace_keyword_literal(self):
with gr.Blocks() as demo:
text = gr.Textbox()
btn = gr.Button(value="Reset")
btn.click(
lambda: gr.update(value="NO_VALUE"),
inputs=[],
outputs=text,
)
output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None)
assert output[0]["value"] == "NO_VALUE"
class TestCallFunction:

View File

@ -3,6 +3,8 @@ import sys
import unittest
import unittest.mock as mock
from contextlib import contextmanager
from functools import partial
from string import capwords
import mlflow
import pytest
@ -49,6 +51,14 @@ class TestInterface(unittest.TestCase):
with self.assertRaises(TypeError):
Interface(lambda x: x, examples=1234)
def test_partial_functions(self):
def greet(name, formatter):
return formatter("Hello " + name + "!")
greet_upper_case = partial(greet, formatter=capwords)
demo = Interface(fn=greet_upper_case, inputs="text", outputs="text")
assert demo("abubakar") == "Hello Abubakar!"
def test_examples_valid_path(self):
path = os.path.join(
os.path.dirname(__file__), "../gradio/test_data/flagged_with_log"