mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Fixes to gr.Examples
(#1952)
* still working * remove exampleset logic * wip * added tests and fix * fixed header issue * more tests * formatting * removed print * Update gradio/examples.py * used context managers * more readable * removed unnecessary start index * formatting
This commit is contained in:
parent
0cdb9b564c
commit
5fe02164f9
@ -4293,7 +4293,12 @@ class Dataset(Clickable, Component):
|
||||
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
|
||||
self.components = [get_component_instance(c, render=False) for c in components]
|
||||
self.type = type
|
||||
self.headers = headers or [c.label for c in self.components]
|
||||
if headers is not None:
|
||||
self.headers = headers
|
||||
elif all([c.label is None for c in self.components]):
|
||||
self.headers = []
|
||||
else:
|
||||
self.headers = [c.label or "" for c in self.components]
|
||||
self.samples = samples
|
||||
|
||||
def get_config(self):
|
||||
|
@ -6,10 +6,12 @@ from __future__ import annotations
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
from gradio import utils
|
||||
from gradio.components import Dataset
|
||||
from gradio.context import Context
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
@ -18,6 +20,7 @@ if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.components import Component
|
||||
|
||||
CACHED_FOLDER = "gradio_cached_examples"
|
||||
LOG_FILE = "log.csv"
|
||||
|
||||
set_documentation_group("component-helpers")
|
||||
|
||||
@ -45,7 +48,7 @@ class Examples:
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided.
|
||||
examples: example inputs that can be clicked to populate specific components. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
|
||||
inputs: the component or list of components corresponding to the examples
|
||||
outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
|
||||
fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
|
||||
@ -60,6 +63,8 @@ class Examples:
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
working_directory = Path().absolute()
|
||||
|
||||
if examples is None:
|
||||
raise ValueError("The parameter `examples` cannot be None")
|
||||
elif isinstance(examples, list) and (
|
||||
@ -75,34 +80,22 @@ class Examples:
|
||||
raise FileNotFoundError(
|
||||
"Could not find examples directory: " + examples
|
||||
)
|
||||
log_file = os.path.join(examples, "log.csv")
|
||||
if not os.path.exists(log_file):
|
||||
working_directory = examples
|
||||
if not os.path.exists(os.path.join(examples, LOG_FILE)):
|
||||
if len(inputs) == 1:
|
||||
exampleset = [
|
||||
[os.path.join(examples, item)] for item in os.listdir(examples)
|
||||
]
|
||||
examples = [[e] for e in os.listdir(examples)]
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
"Could not find log file (required for multiple inputs): "
|
||||
+ log_file
|
||||
+ LOG_FILE
|
||||
)
|
||||
else:
|
||||
with open(log_file) as logs:
|
||||
exampleset = list(csv.reader(logs))
|
||||
exampleset = exampleset[1:] # remove header
|
||||
for i, example in enumerate(exampleset):
|
||||
for j, (component, cell) in enumerate(
|
||||
zip(
|
||||
inputs + outputs,
|
||||
example,
|
||||
)
|
||||
):
|
||||
exampleset[i][j] = component.restore_flagged(
|
||||
examples,
|
||||
cell,
|
||||
None,
|
||||
)
|
||||
examples = exampleset
|
||||
with open(os.path.join(examples, LOG_FILE)) as logs:
|
||||
examples = list(csv.reader(logs))
|
||||
examples = [
|
||||
examples[i][: len(inputs)] for i in range(1, len(examples))
|
||||
] # remove header and unnecessary columns
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"The parameter `examples` must either be a directory or a nested "
|
||||
@ -135,21 +128,22 @@ class Examples:
|
||||
self.cache_examples = cache_examples
|
||||
self.examples_per_page = examples_per_page
|
||||
|
||||
dataset = Dataset(
|
||||
with utils.set_directory(working_directory):
|
||||
self.processed_examples = [
|
||||
[
|
||||
component.preprocess_example(sample)
|
||||
for component, sample in zip(inputs_with_examples, example)
|
||||
]
|
||||
for example in non_none_examples
|
||||
]
|
||||
|
||||
self.dataset = Dataset(
|
||||
components=inputs_with_examples,
|
||||
samples=non_none_examples,
|
||||
type="index",
|
||||
)
|
||||
|
||||
self.processed_examples = [
|
||||
[
|
||||
component.preprocess_example(sample)
|
||||
for component, sample in zip(inputs_with_examples, example)
|
||||
]
|
||||
for example in non_none_examples
|
||||
]
|
||||
|
||||
self.cached_folder = os.path.join(CACHED_FOLDER, str(dataset._id))
|
||||
self.cached_folder = os.path.join(CACHED_FOLDER, str(self.dataset._id))
|
||||
self.cached_file = os.path.join(self.cached_folder, "log.csv")
|
||||
if cache_examples:
|
||||
self.cache_interface_examples()
|
||||
@ -163,13 +157,14 @@ class Examples:
|
||||
processed_example = self.processed_examples[example_id]
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
dataset.click(
|
||||
load_example,
|
||||
inputs=[dataset],
|
||||
outputs=inputs_with_examples + (outputs if cache_examples else []),
|
||||
_postprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
if Context.root_block:
|
||||
self.dataset.click(
|
||||
load_example,
|
||||
inputs=[self.dataset],
|
||||
outputs=inputs_with_examples + (outputs if cache_examples else []),
|
||||
_postprocess=False,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def cache_interface_examples(self) -> None:
|
||||
"""Caches all of the examples from an interface."""
|
||||
|
@ -11,9 +11,11 @@ import os
|
||||
import pkgutil
|
||||
import random
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from distutils.version import StrictVersion
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, NewType, Tuple, Type
|
||||
|
||||
import aiohttp
|
||||
@ -585,3 +587,14 @@ class Request:
|
||||
@property
|
||||
def status(self):
|
||||
return self._status
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_directory(path: Path):
|
||||
"""Context manager that sets the working directory to the given path."""
|
||||
origin = Path().absolute()
|
||||
try:
|
||||
os.chdir(path)
|
||||
yield
|
||||
finally:
|
||||
os.chdir(origin)
|
||||
|
@ -1,19 +1,71 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from gradio import Interface, examples
|
||||
import gradio as gr
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestProcessExamples(unittest.TestCase):
|
||||
class TestExamples:
|
||||
def test_handle_single_input(self):
|
||||
examples = gr.Examples(["hello", "hi"], gr.Textbox())
|
||||
assert examples.processed_examples == [["hello"], ["hi"]]
|
||||
|
||||
examples = gr.Examples([["hello"]], gr.Textbox())
|
||||
assert examples.processed_examples == [["hello"]]
|
||||
|
||||
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
|
||||
assert examples.processed_examples == [[gr.media_data.BASE64_IMAGE]]
|
||||
|
||||
def test_handle_multiple_inputs(self):
|
||||
examples = gr.Examples(
|
||||
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
|
||||
)
|
||||
assert examples.processed_examples == [["hello", gr.media_data.BASE64_IMAGE]]
|
||||
|
||||
def test_handle_directory(self):
|
||||
examples = gr.Examples("test/test_files/images", gr.Image())
|
||||
assert examples.processed_examples == [
|
||||
[gr.media_data.BASE64_IMAGE],
|
||||
[gr.media_data.BASE64_IMAGE],
|
||||
]
|
||||
|
||||
def test_handle_directory_with_log_file(self):
|
||||
examples = gr.Examples(
|
||||
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
|
||||
)
|
||||
assert examples.processed_examples == [
|
||||
[gr.media_data.BASE64_IMAGE, "hello"],
|
||||
[gr.media_data.BASE64_IMAGE, "hi"],
|
||||
]
|
||||
|
||||
|
||||
class TestExamplesDataset:
|
||||
def test_no_headers(self):
|
||||
examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
|
||||
assert examples.dataset.headers == []
|
||||
|
||||
def test_all_headers(self):
|
||||
examples = gr.Examples(
|
||||
"test/test_files/images_log",
|
||||
[gr.Image(label="im"), gr.Text(label="your text")],
|
||||
)
|
||||
assert examples.dataset.headers == ["im", "your text"]
|
||||
|
||||
def test_some_headers(self):
|
||||
examples = gr.Examples(
|
||||
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
|
||||
)
|
||||
assert examples.dataset.headers == ["im", ""]
|
||||
|
||||
|
||||
class TestProcessExamples:
|
||||
def test_process_example(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
|
||||
io = gr.Interface(lambda x: "Hello " + x, "text", "text", examples=[["World"]])
|
||||
prediction = io.examples_handler.process_example(0)
|
||||
self.assertEquals(prediction[0], "Hello World")
|
||||
assert prediction[0] == "Hello World"
|
||||
|
||||
def test_caching(self):
|
||||
io = Interface(
|
||||
io = gr.Interface(
|
||||
lambda x: "Hello " + x,
|
||||
"text",
|
||||
"text",
|
||||
@ -23,8 +75,4 @@ class TestProcessExamples(unittest.TestCase):
|
||||
io.examples_handler.cache_interface_examples()
|
||||
prediction = io.examples_handler.load_from_cache(1)
|
||||
io.close()
|
||||
self.assertEquals(prediction[0], "Hello Dunya")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
assert prediction[0] == "Hello Dunya"
|
||||
|
BIN
test/test_files/images/bus.png
Normal file
BIN
test/test_files/images/bus.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
BIN
test/test_files/images/bus_copy.png
Normal file
BIN
test/test_files/images/bus_copy.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
BIN
test/test_files/images_log/im/bus.png
Normal file
BIN
test/test_files/images_log/im/bus.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
BIN
test/test_files/images_log/im/bus_copy.png
Normal file
BIN
test/test_files/images_log/im/bus_copy.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.9 KiB |
3
test/test_files/images_log/log.csv
Normal file
3
test/test_files/images_log/log.csv
Normal file
@ -0,0 +1,3 @@
|
||||
im,text
|
||||
im/bus.png,hello
|
||||
im/bus_copy.png,hi
|
|
Loading…
Reference in New Issue
Block a user