added tests for process examples

This commit is contained in:
Abubakar Abid 2022-01-26 00:05:50 -06:00
parent 9c478e2d50
commit 7f23f9b326
2 changed files with 50 additions and 4 deletions

View File

@ -1,15 +1,27 @@
"""
Defines helper methods useful for loading and caching Interface examples.
"""
from __future__ import annotations
import csv
import os
import shutil
from typing import Any, List
from typing import Any, List, Tuple, TYPE_CHECKING
from gradio.flagging import CSVLogger
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
CACHED_FOLDER = "gradio_cached_examples"
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
def process_example(interface, example_id: int):
def process_example(
interface: Interface,
example_id: int
) -> Tuple[List[Any], List[float]]:
"""Loads an example from the interface and returns its prediction."""
example_set = interface.examples[example_id]
raw_input = [
interface.input_components[i].preprocess_example(example)
@ -19,7 +31,10 @@ def process_example(interface, example_id: int):
return prediction, durations
def cache_interface_examples(interface) -> None:
def cache_interface_examples(
interface: Interface
) -> None:
"""Caches all of the examples from an interface."""
if os.path.exists(CACHE_FILE):
print(
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
@ -39,7 +54,11 @@ def cache_interface_examples(interface) -> None:
raise e
def load_from_cache(interface, example_id: int) -> List[Any]:
def load_from_cache(
interface: Interface,
example_id: int
) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header

View File

@ -0,0 +1,27 @@
import os
import unittest
from gradio import Interface, process_examples
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestProcessExamples(unittest.TestCase):
def test_process_example(self):
io = Interface(lambda x: "Hello " + x, "text", "text",
examples=[["World"]])
prediction, _ = process_examples.process_example(io, 0)
self.assertEquals(prediction[0], "Hello World")
def test_caching(self):
io = Interface(lambda x: "Hello " + x, "text", "text",
examples=[["World"], ["Dunya"], ["Monde"]])
io.launch(prevent_thread_lock=True)
process_examples.cache_interface_examples(io)
prediction = process_examples.load_from_cache(io, 1)
io.close()
self.assertEquals(prediction[0], "Hello Dunya")
if __name__ == "__main__":
unittest.main()