mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
Merge branch 'abidlabs/external' of https://github.com/gradio-app/gradio into abidlabs/external
This commit is contained in:
commit
1c14d2d2bd
@ -1,6 +1,4 @@
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.external import load_interface
|
||||
from gradio.compound import Parallel, Sequential
|
||||
import pkg_resources
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
|
@ -1,64 +0,0 @@
|
||||
from gradio.interface import Interface
|
||||
|
||||
class Parallel:
    """Run several interfaces side by side over the same inputs.

    The wrapped interfaces are merged into a single compound ``Interface``
    whose function list is the concatenation of every wrapped interface's
    prediction functions. Input and output components are taken from the
    first interface in the list.
    """

    def __init__(self, interfaces):
        # Concatenate every wrapped interface's prediction functions.
        prediction_fns = []
        for wrapped in interfaces:
            prediction_fns.extend(wrapped.predict)

        self.interfaces = interfaces
        self.predict = prediction_fns
        # NOTE(review): components come from interfaces[0] only — presumably
        # all wrapped interfaces share the same layout; confirm with callers.
        self.compound_interface = Interface(
            fn=prediction_fns,
            inputs=interfaces[0].input_interfaces,
            outputs=interfaces[0].output_interfaces)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # Assemble the description from joined fragments rather than
        # repeated string concatenation; "&" separates interfaces.
        fragments = ["Gradio Parallel Interface, consisting of:",
                     "\n-----------------------------------------"]
        final_index = len(self.interfaces) - 1
        for position, wrapped in enumerate(self.interfaces):
            fragments.append("\n " + str(wrapped).replace("\n", "\n "))
            if position != final_index:  # no separator after the last one
                fragments.append("\n&")
        return "".join(fragments)

    def launch(self, *args, **kwargs):
        """Delegate launching to the underlying compound interface."""
        return self.compound_interface.launch(*args, **kwargs)
|
||||
|
||||
|
||||
class Sequential:
    """Chain interfaces so each one's output feeds the next one's input.

    The chain is exposed as a single compound ``Interface`` whose function
    pipes a value through every wrapped prediction function in order. Input
    components come from the first interface, outputs from the last.
    """

    def __init__(self, interfaces):
        # Collect every prediction function in chain order.
        prediction_fns = []
        for wrapped in interfaces:
            prediction_fns.extend(wrapped.predict)

        def chained_fn(inp):
            # Thread the value through each function in sequence.
            value = inp
            for fn in prediction_fns:
                value = fn(value)
            return value

        # Give the composed function a readable name, e.g. "f => g => h".
        chained_fn.__name__ = " => ".join(f.__name__ for f in prediction_fns)

        self.interfaces = interfaces
        self.predict = [chained_fn]
        self.input_interfaces = interfaces[0].input_interfaces
        self.output_interfaces = interfaces[-1].output_interfaces
        self.compound_interface = Interface(
            fn=chained_fn, inputs=self.input_interfaces,
            outputs=self.output_interfaces)

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # Assemble the description from joined fragments rather than
        # repeated string concatenation; "=>" separates interfaces.
        fragments = ["Gradio Sequential Interface, consisting of:",
                     "\n-----------------------------------------"]
        final_index = len(self.interfaces) - 1
        for position, wrapped in enumerate(self.interfaces):
            fragments.append("\n " + str(wrapped).replace("\n", "\n "))
            if position != final_index:  # no separator after the last one
                fragments.append("\n=>")
        return "".join(fragments)

    def launch(self, *args, **kwargs):
        """Delegate launching to the underlying compound interface."""
        return self.compound_interface.launch(*args, **kwargs)
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import requests
|
||||
from gradio.interface import Interface
|
||||
from gradio import inputs, outputs
|
||||
|
||||
|
||||
@ -23,40 +22,40 @@ def get_huggingface_interface(model_name, api_key):
|
||||
},
|
||||
'text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Question"),
|
||||
'preprocess': lambda x: x,
|
||||
'outputs': outputs.Textbox(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["generated_text"],
|
||||
# 'examples': [['My name is Clara and I am']]
|
||||
},
|
||||
'summarization': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Summary"),
|
||||
'preprocess': lambda x: x,
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["summary_text"]
|
||||
},
|
||||
'translation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Translation"),
|
||||
'preprocess': lambda x: x,
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["translation_text"]
|
||||
},
|
||||
'text2text-generation': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Textbox(label="Generated Text"),
|
||||
'preprocess': lambda x: x,
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: r[0]["generated_text"]
|
||||
},
|
||||
'text-classification': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': "label",
|
||||
'preprocess': lambda x: x,
|
||||
'outputs': outputs.Label(label="Class"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {'Negative': r[0][0]["score"],
|
||||
'Positive': r[0][1]["score"]}
|
||||
},
|
||||
'fill-mask': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': "label",
|
||||
'preprocess': lambda x: x,
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r}
|
||||
},
|
||||
'zero-shot-classification': {
|
||||
@ -79,6 +78,7 @@ def get_huggingface_interface(model_name, api_key):
|
||||
|
||||
def query_huggingface_api(*input):
|
||||
payload = pipeline['preprocess'](*input)
|
||||
payload.update({'options': {'wait_for_model': True}})
|
||||
data = json.dumps(payload)
|
||||
response = requests.request("POST", api_url, data=data)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
@ -127,10 +127,15 @@ def get_gradio_interface(model_name, api_key):
|
||||
|
||||
return interface_info
|
||||
|
||||
def load_interface(model, src, api_key=None, verbose=True):
|
||||
def load_interface(name, src=None, api_key=None):
|
||||
if src is None:
|
||||
tokens = name.split("/")
|
||||
assert len(tokens) > 1, "Either `src` parameter must be provided, or `name` must be formatted as \{src\}/\{repo name\}"
|
||||
src = tokens[0]
|
||||
name = "/".join(tokens[1:])
|
||||
assert src.lower() in repos, "parameter: src must be one of {}".format(repos.keys())
|
||||
interface_info = repos[src](model, api_key)
|
||||
return Interface(**interface_info)
|
||||
interface_info = repos[src](name, api_key)
|
||||
return interface_info
|
||||
|
||||
repos = {
|
||||
# for each repo, we have a method that returns the Interface given the model name & optionally an api_key
|
||||
|
@ -8,6 +8,7 @@ from gradio.inputs import InputComponent
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils
|
||||
from gradio.interpretation import quantify_difference_in_label
|
||||
from gradio.external import load_interface
|
||||
from gradio import encryptor
|
||||
import pkg_resources
|
||||
import requests
|
||||
@ -45,14 +46,20 @@ class Interface:
|
||||
return list(
|
||||
Interface.instances)
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||
@classmethod
|
||||
def load(cls, name, src=None, api_key=None, **kwargs):
|
||||
interface_info = load_interface(name, src, api_key)
|
||||
interface_info.update(kwargs)
|
||||
return cls(**interface_info)
|
||||
|
||||
def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
|
||||
examples_per_page=10, live=False,
|
||||
layout="horizontal", show_input=True, show_output=True,
|
||||
capture_session=False, interpretation=None,
|
||||
capture_session=False, interpretation=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
show_tips=True, embedding=None, flagging_dir="flagged", analytics_enabled=True):
|
||||
show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True):
|
||||
|
||||
"""
|
||||
Parameters:
|
||||
@ -80,7 +87,6 @@ class Interface:
|
||||
flagging_dir (str): what to name the dir where flagged data is stored.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
"""
|
||||
|
||||
def get_input_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
shortcut = InputComponent.get_all_shortcut_implementations()[iface]
|
||||
@ -103,6 +109,8 @@ class Interface:
|
||||
"`OutputComponent`"
|
||||
)
|
||||
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
if isinstance(inputs, list):
|
||||
self.input_interfaces = [get_input_instance(i) for i in inputs]
|
||||
else:
|
||||
@ -111,12 +119,13 @@ class Interface:
|
||||
self.output_interfaces = [get_output_instance(i) for i in outputs]
|
||||
else:
|
||||
self.output_interfaces = [get_output_instance(outputs)]
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
|
||||
self.output_interfaces *= len(fn)
|
||||
# self.original_output_interfaces = copy.copy(self.output_interfaces)
|
||||
if repeat_outputs_per_model:
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
self.function_names = [func.__name__ for func in fn]
|
||||
self.__name__ = ", ".join(self.function_names)
|
||||
self.verbose = verbose
|
||||
self.status = "OFF"
|
||||
self.live = live
|
||||
@ -135,6 +144,8 @@ class Interface:
|
||||
article = markdown2.markdown(article)
|
||||
self.article = article
|
||||
self.thumbnail = thumbnail
|
||||
self.height = height
|
||||
self.width = width
|
||||
if css is not None and os.path.exists(css):
|
||||
with open(css) as css_file:
|
||||
self.css = css_file.read()
|
||||
@ -192,12 +203,15 @@ class Interface:
|
||||
data=data, timeout=3)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def __call__(self, params_per_function):
|
||||
return self.predict[0](params_per_function)
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
repr = "Gradio interface for function: {}".format(",".join(fn.__name__ for fn in self.predict))
|
||||
repr = "Gradio Interface for: {}".format(", ".join(fn.__name__ for fn in self.predict))
|
||||
repr += "\n" + "-"*len(repr)
|
||||
repr += "\ninputs:"
|
||||
for component in self.input_interfaces:
|
||||
@ -387,7 +401,7 @@ class Interface:
|
||||
print("PASSED")
|
||||
continue
|
||||
|
||||
def launch(self, inline=None, inbrowser=None, share=False, debug=False, auth=None, auth_message=None, private_endpoint=None, prevent_thread_lock=False):
|
||||
def launch(self, inline=None, inbrowser=None, share=False, debug=False, auth=None, auth_message=None, private_endpoint=None):
|
||||
"""
|
||||
Parameters:
|
||||
inline (bool): whether to display in the interface inline on python notebooks.
|
||||
@ -477,9 +491,9 @@ class Interface:
|
||||
if share:
|
||||
while not networking.url_ok(share_url):
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=900, height=500))
|
||||
display(IFrame(share_url, width=self.width, height=self.height))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=900, height=500))
|
||||
display(IFrame(path_to_local_server, width=self.width, height=self.height))
|
||||
except ImportError:
|
||||
pass # IPython is not available so does not print inline.
|
||||
|
||||
@ -493,12 +507,8 @@ class Interface:
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
is_in_interactive_mode = bool(getattr(sys, 'ps1', sys.flags.interactive))
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
print("going to lock thread and run in foreground ...")
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
return app, path_to_local_server, share_url, thread
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
|
||||
def integrate(self, comet_ml=None):
|
||||
@ -508,6 +518,7 @@ class Interface:
|
||||
else:
|
||||
comet_ml.log_text(self.local_url)
|
||||
comet_ml.end()
|
||||
|
||||
|
||||
def show_tip(io):
|
||||
if not(io.show_tips):
|
||||
|
77
gradio/transforms.py
Normal file
77
gradio/transforms.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""
|
||||
Ways to transform interfaces to produce new interfaces
|
||||
"""
|
||||
from gradio.interface import Interface
|
||||
|
||||
def compare(*interfaces, **options):
    """Merge several interfaces into one that shows their outputs side by
    side for the same inputs.

    Input components come from the first interface; the output components
    are every interface's outputs concatenated. ``repeat_outputs_per_model``
    is disabled so each model keeps exactly its own output components.
    Extra keyword ``options`` are forwarded to the ``Interface`` constructor.
    """
    prediction_fns = [fn for io in interfaces for fn in io.predict]
    output_components = [out for io in interfaces for out in io.output_interfaces]

    return Interface(fn=prediction_fns,
                     inputs=interfaces[0].input_interfaces,
                     outputs=output_components,
                     repeat_outputs_per_model=False, **options)
|
||||
|
||||
# class Compare:
|
||||
# def __init__(self, interfaces):
|
||||
|
||||
# self.interfaces = interfaces
|
||||
# self.predict = fns
|
||||
# self.compound_interface = Interface(
|
||||
# fn=fns, inputs=interfaces[0].input_interfaces, outputs=interfaces[0].output_interfaces)
|
||||
|
||||
# def __str__(self):
|
||||
# return self.__repr__()
|
||||
|
||||
# def __repr__(self):
|
||||
# repr = "Gradio Comparative Interface, consisting of:"
|
||||
# repr += "\n-----------------------------------------"
|
||||
# for i, io in enumerate(self.interfaces):
|
||||
# repr += "\n " + str(io).replace("\n", "\n ")
|
||||
# if i < len(self.interfaces) - 1: # Don't apply to last interface.
|
||||
# repr += "\n&"
|
||||
# return repr
|
||||
|
||||
# def launch(self, *args, **kwargs):
|
||||
# return self.compound_interface.launch(*args, **kwargs)
|
||||
|
||||
def connect():
    """Placeholder for chaining interfaces together; not yet implemented."""
    return None
|
||||
|
||||
# class Cascade:
|
||||
# def __init__(self, interfaces):
|
||||
# fns = []
|
||||
# for io in interfaces:
|
||||
# fns.extend(io.predict)
|
||||
|
||||
# def cascaded_process(inp):
|
||||
# out = inp
|
||||
# for fn in fns:
|
||||
# out = fn(out)
|
||||
# return out
|
||||
|
||||
# cascaded_fn.__name__ = " => ".join([f.__name__ for f in fns])
|
||||
|
||||
# self.interfaces = interfaces
|
||||
# self.predict = [cascaded_fn]
|
||||
# self.input_interfaces = interfaces[0].input_interfaces
|
||||
# self.output_interfaces = interfaces[-1].output_interfaces
|
||||
# self.compound_interface = Interface(
|
||||
# fn=cascaded_fn, inputs=self.input_interfaces, outputs=self.output_interfaces)
|
||||
|
||||
# def __str__(self):
|
||||
# return self.__repr__()
|
||||
|
||||
# def __repr__(self):
|
||||
# repr = "Gradio Cascaded Interface, consisting of:"
|
||||
# repr += "\n-----------------------------------------"
|
||||
# for i, io in enumerate(self.interfaces):
|
||||
# repr += "\n " + str(io).replace("\n", "\n ")
|
||||
# if i < len(self.interfaces) - 1: # Don't apply to last interface.
|
||||
# repr += "\n=>"
|
||||
# return repr
|
||||
|
||||
# def launch(self, *args, **kwargs):
|
||||
# return self.compound_interface.launch(*args, **kwargs)
|
Loading…
Reference in New Issue
Block a user