minor cleanup

This commit is contained in:
Abubakar Abid 2021-12-27 15:20:14 -06:00
parent 173b382e1b
commit 1dd1298112

View File

@ -63,11 +63,9 @@ class Interface:
(gradio.Interface): a Gradio Interface object for the given model
"""
interface_info = load_interface(name, src, api_key, alias)
# create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
# and that can cause some issues since the internal prediction function may rely on the original interface_info dict
kwargs = dict(interface_info, **kwargs)
interface = cls(**kwargs)
interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess
interface.api_mode = True # So interface doesn't run pre/postprocess.
return interface
@classmethod
@ -76,7 +74,8 @@ class Interface:
pipeline: transformers.Pipeline,
**kwargs) -> Interface:
"""
Class method to construct an Interface from a Hugging Face transformers.Pipeline.
Construct an Interface from a Hugging Face transformers.Pipeline.
Parameters:
pipeline (transformers.Pipeline):
Returns:
(gradio.Interface): a Gradio Interface object from the given Pipeline
@ -119,8 +118,7 @@ class Interface:
analytics_enabled: Optional[bool] = None,
enable_queue=None,
api_mode=None,
flagging_callback: FlaggingCallback = CSVLogger()
):
flagging_callback: FlaggingCallback = CSVLogger()):
"""
Parameters:
fn (Union[Callable, List[Callable]]): the function to wrap an interface around.
@ -221,7 +219,8 @@ class Interface:
self.simple_server = None
self.allow_screenshot = allow_screenshot
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
# For allow_flagging and analytics_enabled: (1) first check for
# parameter, (2) check for environment variable, (3) default to True
self.analytics_enabled = analytics_enabled if analytics_enabled is not None else os.getenv("GRADIO_ANALYTICS_ENABLED", "True")=="True"
self.allow_flagging = allow_flagging if allow_flagging is not None else os.getenv("GRADIO_ALLOW_FLAGGING", "True")=="True"
self.flagging_options = flagging_options
@ -318,14 +317,16 @@ class Interface:
called_directly: bool = False
) -> List[Any] | Tuple[List[Any], List[float]]:
"""
This is the method that actually runs the prediction function with the given (processed) inputs.
Runs the prediction function with the given (already processed) inputs.
Parameters:
processed_input (list): A list of processed inputs.
return_duration (bool): Whether to return the duration of the prediction.
called_directly (bool): Whether the prediction is being called directly (i.e. as a function, not through the GUI).
called_directly (bool): Whether the prediction is being called
directly (i.e. as a function, not through the GUI).
Returns:
predictions (list): A list of predictions (not post-processed).
durations (list): A list of durations for each prediction (only if `return_duration` is True).
durations (list): A list of durations for each prediction
(only returned if `return_duration` is True).
"""
if self.api_mode: # Serialize the input
processed_input = [input_component.serialize(processed_input[i], called_directly)
@ -362,6 +363,8 @@ class Interface:
raw_input: List[Any]
) -> Tuple[List[Any], List[float]]:
"""
First preprocesses the input, then runs prediction using
self.run_prediction(), then postprocesses the output.
Parameters:
raw_input: a list of raw inputs to process and apply the prediction(s) on.
Returns:
@ -369,11 +372,13 @@ class Interface:
duration: a list of time deltas measuring inference time for each prediction fn.
"""
processed_input = [input_component.preprocess(raw_input[i])
for i, input_component in enumerate(self.input_components)]
for i, input_component in enumerate(
self.input_components)]
predictions, durations = self.run_prediction(
processed_input, return_duration=True)
processed_output = [output_component.postprocess(predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
processed_output = [output_component.postprocess(
predictions[i]) if predictions[i] is not None else None
for i, output_component in enumerate(self.output_components)]
return processed_output, durations
def interpret(
@ -401,13 +406,12 @@ class Interface:
def test_launch(self) -> None:
for predict_fn in self.predict:
print("Test launch: {}()...".format(predict_fn.__name__), end=' ')
raw_input = []
for input_component in self.input_components:
if input_component.test_input is None: # If no test input is defined for that input interface
if input_component.test_input is None:
print("SKIPPED")
break
else: # If a test input is defined for each interface object
else:
raw_input.append(input_component.test_input)
else:
self.process(raw_input)
@ -465,8 +469,8 @@ class Interface:
self.auth_message = auth_message
self.show_tips = show_tips
self.show_error = show_error
self.height = self.height or height # if height is not set in constructor, use the one provided here
self.width = self.width or width # if width is not set in constructor, use the one provided here
self.height = self.height or height
self.width = self.width or width
if self.encrypt is None:
self.encrypt = encrypt
@ -479,7 +483,6 @@ class Interface:
if self.allow_flagging:
self.flagging_callback.setup(self.flagging_dir)
# Launch local flask server
server_port, path_to_local_server, app, thread, server = networking.start_server(
self, server_name, server_port, self.auth)
self.local_url = path_to_local_server
@ -528,7 +531,6 @@ class Interface:
print(strings.en["PUBLIC_SHARE_TRUE"])
share_url = None
# Open a browser tab with the interface.
if inbrowser:
link = share_url if share else path_to_local_server
webbrowser.open(link)
@ -547,7 +549,7 @@ class Interface:
display(IFrame(path_to_local_server,
width=self.width, height=self.height))
except ImportError:
pass # IPython is not available so does not print inline.
pass
data = {
'launch_method': 'browser' if inbrowser else 'inline',
@ -590,9 +592,7 @@ class Interface:
self.server_thread.join()
if verbose:
print("Closing server running on port: {}".format(self.server_port))
except AttributeError: # can't close if not running
pass
except OSError: # sometimes OSError is thrown when shutting down
except (AttributeError, OSError): # can't close if not running
pass
def integrate(
@ -640,7 +640,6 @@ class Interface:
def close_all(verbose: bool = True) -> None:
# Tries to close all running interfaces, but method is a little flaky.
for io in Interface.get_instances():
io.close(verbose)