mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
added typing; deprecated capture_session; fixed test
This commit is contained in:
parent
40adb8537b
commit
bd433577b6
@ -3,10 +3,13 @@ This is the core file in the `gradio` package, and defines the Interface class,
|
||||
interface using the input and output types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import csv
|
||||
import getpass
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Callable
|
||||
import markdown2
|
||||
import numpy as np
|
||||
import os
|
||||
@ -14,6 +17,7 @@ import pkg_resources
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from typing import Union, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
import warnings
|
||||
import webbrowser
|
||||
import weakref
|
||||
@ -21,9 +25,13 @@ import weakref
|
||||
from gradio import networking, strings, utils, encryptor, queue
|
||||
from gradio.external import load_interface, load_from_pipeline
|
||||
from gradio.flagging import FlaggingCallback, CSVLogger
|
||||
from gradio.inputs import get_input_instance
|
||||
from gradio.inputs import get_input_instance, InputComponent
|
||||
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio.outputs import get_output_instance, OutputComponent
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (always False at runtime).
|
||||
import transformers
|
||||
import flask
|
||||
|
||||
|
||||
class Interface:
|
||||
@ -44,7 +52,12 @@ class Interface:
|
||||
return list(Interface.instances)
|
||||
|
||||
@classmethod
|
||||
def load(cls, name: str, src: str = None, api_key: str = None, alias: str = None, **kwargs):
|
||||
def load(cls,
|
||||
name: str,
|
||||
src: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
alias: Optional[str] = None,
|
||||
**kwargs):
|
||||
"""
|
||||
Class method to construct an Interface from an external source repository, such as huggingface.
|
||||
Parameters:
|
||||
@ -64,7 +77,10 @@ class Interface:
|
||||
return interface
|
||||
|
||||
@classmethod
|
||||
def from_pipeline(cls, pipeline: "transformers.Pipeline", **kwargs):
|
||||
def from_pipeline(
|
||||
cls,
|
||||
pipeline: transformers.Pipeline,
|
||||
**kwargs):
|
||||
"""
|
||||
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
|
||||
pipeline (transformers.Pipeline):
|
||||
@ -76,18 +92,46 @@ class Interface:
|
||||
interface = cls(**kwargs)
|
||||
return interface
|
||||
|
||||
def __init__(self, fn, inputs=None, outputs=None, verbose=None, examples=None,
|
||||
examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True,
|
||||
capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, height=None, width=None, allow_screenshot=True, allow_flagging=None, flagging_options=None,
|
||||
encrypt=None, show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None,
|
||||
flagging_callback=CSVLogger()):
|
||||
def __init__(
|
||||
self,
|
||||
fn: Callable,
|
||||
inputs: str | InputComponent | List[str | InputComponent] = None,
|
||||
outputs: str | OutputComponent | List[str | OutputComponent] = None,
|
||||
verbose: bool = None,
|
||||
examples: Union[List[List[Any]], str] = None,
|
||||
examples_per_page: int = 10,
|
||||
live: bool = False,
|
||||
layout: str = "unaligned",
|
||||
show_input: bool = True,
|
||||
show_output: bool = True,
|
||||
capture_session: Optional[bool] = None,
|
||||
interpretation: Optional[bool] = None,
|
||||
num_shap: int = 2.0,
|
||||
theme: Optional[str] = None,
|
||||
repeat_outputs_per_model: bool = True,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
article: Optional[str] = None,
|
||||
thumbnail: Optional[str] = None,
|
||||
css: Optional[str] = None,
|
||||
height=None,
|
||||
width=None,
|
||||
allow_screenshot: bool = True,
|
||||
allow_flagging: bool = None,
|
||||
flagging_options: List[str]=None,
|
||||
encrypt=None,
|
||||
show_tips=None,
|
||||
flagging_dir: str = "flagged",
|
||||
analytics_enabled: Optional[bool] = None,
|
||||
enable_queue=None,
|
||||
api_mode=None,
|
||||
flagging_callback: FlaggingCallback = CSVLogger()
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Callable): the function to wrap an interface around.
|
||||
inputs (Union[str, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
|
||||
outputs (Union[str, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
|
||||
inputs (Union[str, InputComponent, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
|
||||
outputs (Union[str, OutputComponent, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
|
||||
verbose (bool): DEPRECATED. Whether to print detailed information during launch.
|
||||
examples (Union[List[List[Any]], str]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
|
||||
examples_per_page (int): If examples are provided, how many to display per page.
|
||||
@ -343,7 +387,12 @@ class Interface:
|
||||
config["examples"] = self.examples
|
||||
return config
|
||||
|
||||
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
|
||||
def run_prediction(
|
||||
self,
|
||||
processed_input: List[Any],
|
||||
return_duration: bool = False,
|
||||
called_directly: bool = False
|
||||
) -> List[Any] | Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
This is the method that actually runs the prediction function with the given (processed) inputs.
|
||||
Parameters:
|
||||
@ -384,7 +433,10 @@ class Interface:
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process(self, raw_input):
|
||||
def process(
|
||||
self,
|
||||
raw_input: List[Any]
|
||||
) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
Parameters:
|
||||
raw_input: a list of raw inputs to process and apply the prediction(s) on.
|
||||
@ -400,7 +452,10 @@ class Interface:
|
||||
for i, output_component in enumerate(self.output_components)]
|
||||
return processed_output, durations
|
||||
|
||||
def interpret(self, raw_input):
|
||||
def interpret(
|
||||
self,
|
||||
raw_input: List[Any]
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
||||
interpretation for a certain set of UI component types, as well as the custom interpretation case.
|
||||
@ -515,7 +570,11 @@ class Interface:
|
||||
interpretation = [interpretation]
|
||||
return interpretation, []
|
||||
|
||||
def run_until_interrupted(self, thread, path_to_local_server):
|
||||
def run_until_interrupted(
|
||||
self,
|
||||
thread: threading.Thread,
|
||||
path_to_local_server: str
|
||||
):
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.5)
|
||||
@ -543,11 +602,25 @@ class Interface:
|
||||
print("PASSED")
|
||||
continue
|
||||
|
||||
def launch(self, inline=None, inbrowser=None, share=False, debug=False,
|
||||
auth=None, auth_message=None, private_endpoint=None,
|
||||
prevent_thread_lock=False, show_error=True, server_name=None,
|
||||
server_port=None, show_tips=False, enable_queue=False,
|
||||
height=500, width=900, encrypt=False):
|
||||
def launch(
|
||||
self,
|
||||
inline: bool = None,
|
||||
inbrowser: bool = None,
|
||||
share: bool = False,
|
||||
debug: bool = False,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
private_endpoint: Optional[str] = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = True,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
show_tips: bool = False,
|
||||
enable_queue: bool = False,
|
||||
height: int = 500,
|
||||
width: int = 900,
|
||||
encrypt: bool = False
|
||||
) -> Tuple[flask.Flask, str, str]:
|
||||
"""
|
||||
Launches the webserver that serves the UI for the interface.
|
||||
Parameters:
|
||||
@ -564,13 +637,13 @@ class Interface:
|
||||
server_name (str): to make app accessible on local network, set this to "0.0.0.0".
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
width (int): The width in pixels of the <iframe> element containing the interface (used if inline=True)
|
||||
height (int): The height in pixels of the <iframe> element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
Returns:
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
width (bool): The width of the <iframe> element containing the interface (used if inline=True)
|
||||
height (bool): The height of the <iframe> element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
"""
|
||||
# Set up local flask server
|
||||
config = self.get_config_file()
|
||||
@ -702,7 +775,10 @@ class Interface:
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def close(self, verbose=True):
|
||||
def close(
|
||||
self,
|
||||
verbose: bool = True
|
||||
):
|
||||
"""
|
||||
Closes the Interface that was launched. This will close the server and free the port.
|
||||
"""
|
||||
|
@ -189,15 +189,5 @@ class TestInterface(unittest.TestCase):
|
||||
interface.integrate(mlflow=mlflow)
|
||||
mock_post.assert_called_once()
|
||||
|
||||
def test_capture_session(self):
|
||||
interface = Interface(lambda x: x, "textbox", "label", capture_session=True, interpretation=lambda x: 0)
|
||||
interface.session = (mock.MagicMock(), mock.MagicMock())
|
||||
interface.interpret(["quickest brown fox"])
|
||||
interface.session[0].as_default.assert_called_once()
|
||||
interface.session[1].as_default.assert_called_once()
|
||||
interface.run_prediction(["quickest brown fox"])
|
||||
self.assertEqual(interface.session[0].as_default.call_count, 2)
|
||||
self.assertEqual(interface.session[1].as_default.call_count, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user