mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
added tests for process examples
This commit is contained in:
parent
9c478e2d50
commit
7f23f9b326
@ -1,15 +1,27 @@
|
||||
"""
|
||||
Defines helper methods useful for loading and caching Interface examples.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from gradio.flagging import CSVLogger
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio import Interface
|
||||
|
||||
CACHED_FOLDER = "gradio_cached_examples"
|
||||
CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv")
|
||||
|
||||
|
||||
def process_example(interface, example_id: int):
|
||||
def process_example(
|
||||
interface: Interface,
|
||||
example_id: int
|
||||
) -> Tuple[List[Any], List[float]]:
|
||||
"""Loads an example from the interface and returns its prediction."""
|
||||
example_set = interface.examples[example_id]
|
||||
raw_input = [
|
||||
interface.input_components[i].preprocess_example(example)
|
||||
@ -19,7 +31,10 @@ def process_example(interface, example_id: int):
|
||||
return prediction, durations
|
||||
|
||||
|
||||
def cache_interface_examples(interface) -> None:
|
||||
def cache_interface_examples(
|
||||
interface: Interface
|
||||
) -> None:
|
||||
"""Caches all of the examples from an interface."""
|
||||
if os.path.exists(CACHE_FILE):
|
||||
print(
|
||||
f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache."
|
||||
@ -39,7 +54,11 @@ def cache_interface_examples(interface) -> None:
|
||||
raise e
|
||||
|
||||
|
||||
def load_from_cache(interface, example_id: int) -> List[Any]:
|
||||
def load_from_cache(
|
||||
interface: Interface,
|
||||
example_id: int
|
||||
) -> List[Any]:
|
||||
"""Loads a particular cached example for the interface."""
|
||||
with open(CACHE_FILE) as cache:
|
||||
examples = list(csv.reader(cache))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
|
27
test/test_process_examples.py
Normal file
27
test/test_process_examples.py
Normal file
@ -0,0 +1,27 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from gradio import Interface, process_examples
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestProcessExamples(unittest.TestCase):
|
||||
def test_process_example(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text",
|
||||
examples=[["World"]])
|
||||
prediction, _ = process_examples.process_example(io, 0)
|
||||
self.assertEquals(prediction[0], "Hello World")
|
||||
|
||||
def test_caching(self):
|
||||
io = Interface(lambda x: "Hello " + x, "text", "text",
|
||||
examples=[["World"], ["Dunya"], ["Monde"]])
|
||||
io.launch(prevent_thread_lock=True)
|
||||
process_examples.cache_interface_examples(io)
|
||||
prediction = process_examples.load_from_cache(io, 1)
|
||||
io.close()
|
||||
self.assertEquals(prediction[0], "Hello Dunya")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user