Fixes to gr.Examples (#1952)

* still working

* remove exampleset logic

* wip

* added tests and fix

* fixed header issue

* more tests

* formatting

* removed print

* Update gradio/examples.py

* used context managers

* more readable

* removed unnecessary start index

* formatting
This commit is contained in:
Abubakar Abid 2022-08-08 10:35:26 -07:00 committed by GitHub
parent 0cdb9b564c
commit 5fe02164f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 116 additions and 52 deletions

View File

@ -4293,7 +4293,12 @@ class Dataset(Clickable, Component):
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
self.components = [get_component_instance(c, render=False) for c in components]
self.type = type
self.headers = headers or [c.label for c in self.components]
if headers is not None:
self.headers = headers
elif all([c.label is None for c in self.components]):
self.headers = []
else:
self.headers = [c.label or "" for c in self.components]
self.samples = samples
def get_config(self):

View File

@ -6,10 +6,12 @@ from __future__ import annotations
import csv
import os
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from gradio import utils
from gradio.components import Dataset
from gradio.context import Context
from gradio.documentation import document, set_documentation_group
from gradio.flagging import CSVLogger
@ -18,6 +20,7 @@ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component
CACHED_FOLDER = "gradio_cached_examples"
LOG_FILE = "log.csv"
set_documentation_group("component-helpers")
@ -45,7 +48,7 @@ class Examples:
):
"""
Parameters:
examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided.
examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
inputs: the component or list of components corresponding to the examples
outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
@ -60,6 +63,8 @@ class Examples:
if not isinstance(outputs, list):
outputs = [outputs]
working_directory = Path().absolute()
if examples is None:
raise ValueError("The parameter `examples` cannot be None")
elif isinstance(examples, list) and (
@ -75,34 +80,22 @@ class Examples:
raise FileNotFoundError(
"Could not find examples directory: " + examples
)
log_file = os.path.join(examples, "log.csv")
if not os.path.exists(log_file):
working_directory = examples
if not os.path.exists(os.path.join(examples, LOG_FILE)):
if len(inputs) == 1:
exampleset = [
[os.path.join(examples, item)] for item in os.listdir(examples)
]
examples = [[e] for e in os.listdir(examples)]
else:
raise FileNotFoundError(
"Could not find log file (required for multiple inputs): "
+ log_file
+ LOG_FILE
)
else:
with open(log_file) as logs:
exampleset = list(csv.reader(logs))
exampleset = exampleset[1:] # remove header
for i, example in enumerate(exampleset):
for j, (component, cell) in enumerate(
zip(
inputs + outputs,
example,
)
):
exampleset[i][j] = component.restore_flagged(
examples,
cell,
None,
)
examples = exampleset
with open(os.path.join(examples, LOG_FILE)) as logs:
examples = list(csv.reader(logs))
examples = [
examples[i][: len(inputs)] for i in range(1, len(examples))
] # remove header and unnecessary columns
else:
raise ValueError(
"The parameter `examples` must either be a directory or a nested "
@ -135,21 +128,22 @@ class Examples:
self.cache_examples = cache_examples
self.examples_per_page = examples_per_page
dataset = Dataset(
with utils.set_directory(working_directory):
self.processed_examples = [
[
component.preprocess_example(sample)
for component, sample in zip(inputs_with_examples, example)
]
for example in non_none_examples
]
self.dataset = Dataset(
components=inputs_with_examples,
samples=non_none_examples,
type="index",
)
self.processed_examples = [
[
component.preprocess_example(sample)
for component, sample in zip(inputs_with_examples, example)
]
for example in non_none_examples
]
self.cached_folder = os.path.join(CACHED_FOLDER, str(dataset._id))
self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
self.cached_file = os.path.join(self.cached_folder, "log.csv")
if cache_examples:
self.cache_interface_examples()
@ -163,13 +157,14 @@ class Examples:
processed_example = self.processed_examples[example_id]
return utils.resolve_singleton(processed_example)
dataset.click(
load_example,
inputs=[dataset],
outputs=inputs_with_examples + (outputs if cache_examples else []),
_postprocess=False,
queue=False,
)
if Context.root_block:
self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=inputs_with_examples + (outputs if cache_examples else []),
_postprocess=False,
queue=False,
)
def cache_interface_examples(self) -> None:
"""Caches all of the examples from an interface."""

View File

@ -11,9 +11,11 @@ import os
import pkgutil
import random
import warnings
from contextlib import contextmanager
from copy import deepcopy
from distutils.version import StrictVersion
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, NewType, Tuple, Type
import aiohttp
@ -585,3 +587,14 @@ class Request:
@property
def status(self):
return self._status
@contextmanager
def set_directory(path: Path):
    """Temporarily make `path` the process's current working directory.

    The directory that was current when the context was entered is restored
    on exit, whether the body of the `with` block finishes normally or raises.

    Parameters:
        path: directory to switch into for the duration of the `with` block.
    """
    # Remember where we started so the `finally` clause can always go back.
    previous_directory = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(previous_directory)

View File

@ -1,19 +1,71 @@
import os
import unittest
from gradio import Interface, examples
import gradio as gr
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestProcessExamples(unittest.TestCase):
class TestExamples:
    """Checks how gr.Examples turns its `examples` argument into processed_examples."""

    def test_handle_single_input(self):
        """A single input component accepts a flat list, a nested list, or file paths."""
        # Flat list of samples: each sample is wrapped into its own row.
        flat = gr.Examples(examples=["hello", "hi"], inputs=gr.Textbox())
        assert flat.processed_examples == [["hello"], ["hi"]]

        # Already-nested list: rows are kept as they are.
        nested = gr.Examples(examples=[["hello"]], inputs=gr.Textbox())
        assert nested.processed_examples == [["hello"]]

        # File path for an Image component: compared against the reference base64 image.
        image_only = gr.Examples(
            examples=["test/test_files/bus.png"], inputs=gr.Image()
        )
        assert image_only.processed_examples == [[gr.media_data.BASE64_IMAGE]]

    def test_handle_multiple_inputs(self):
        """One row holding a value for each of two input components."""
        rows = [["hello", "test/test_files/bus.png"]]
        components = [gr.Textbox(), gr.Image()]
        handler = gr.Examples(rows, components)
        expected_row = ["hello", gr.media_data.BASE64_IMAGE]
        assert handler.processed_examples == [expected_row]

    def test_handle_directory(self):
        """A directory path with a single Image input yields two image rows."""
        handler = gr.Examples("test/test_files/images", gr.Image())
        expected_row = [gr.media_data.BASE64_IMAGE]
        assert handler.processed_examples == [expected_row, expected_row]

    def test_handle_directory_with_log_file(self):
        """A directory path with an Image and a Text input yields paired rows."""
        components = [gr.Image(label="im"), gr.Text()]
        handler = gr.Examples("test/test_files/images_log", components)
        reference_image = gr.media_data.BASE64_IMAGE
        assert handler.processed_examples == [
            [reference_image, "hello"],
            [reference_image, "hi"],
        ]
class TestExamplesDataset:
    """Checks the headers of the Dataset that gr.Examples exposes as `.dataset`."""

    def test_no_headers(self):
        """No component has a label, so the header list is empty."""
        unlabeled = [gr.Image(), gr.Text()]
        handler = gr.Examples("test/test_files/images_log", unlabeled)
        assert handler.dataset.headers == []

    def test_all_headers(self):
        """Every component has a label, so the headers are those labels in order."""
        labeled = [gr.Image(label="im"), gr.Text(label="your text")]
        handler = gr.Examples("test/test_files/images_log", labeled)
        assert handler.dataset.headers == ["im", "your text"]

    def test_some_headers(self):
        """Only some components have a label; unlabeled ones get an empty string."""
        partly_labeled = [gr.Image(label="im"), gr.Text()]
        handler = gr.Examples("test/test_files/images_log", partly_labeled)
        assert handler.dataset.headers == ["im", ""]
class TestProcessExamples:
def test_process_example(self):
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
io = gr.Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
prediction = io.examples_handler.process_example(0)
self.assertEquals(prediction[0], "Hello World")
assert prediction[0] == "Hello World"
def test_caching(self):
io = Interface(
io = gr.Interface(
lambda x: "Hello " + x,
"text",
"text",
@ -23,8 +75,4 @@ class TestProcessExamples(unittest.TestCase):
io.examples_handler.cache_interface_examples()
prediction = io.examples_handler.load_from_cache(1)
io.close()
self.assertEquals(prediction[0], "Hello Dunya")
if __name__ == "__main__":
unittest.main()
assert prediction[0] == "Hello Dunya"

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

@ -0,0 +1,3 @@
im,text
im/bus.png,hello
im/bus_copy.png,hi
1 im text
2 im/bus.png hello
3 im/bus_copy.png hi