integrate() method moved to Blocks (#1776)

* integrating blocks

* formatting

* added tests

* tests

* formatting

* added integrate() to docs

* typing

* typing
This commit is contained in:
Abubakar Abid 2022-07-19 06:47:40 +01:00 committed by GitHub
parent a44f8f7780
commit 4149d00822
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 144 additions and 59 deletions

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import copy
import getpass
import inspect
@ -9,6 +8,7 @@ import random
import sys
import time
import webbrowser
from types import ModuleType
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
import anyio
@ -23,10 +23,12 @@ from gradio.utils import component_or_layout_class, delete_none
set_documentation_group("blocks")
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import comet_ml
import mlflow
import wandb
from fastapi.applications import FastAPI
from gradio.components import Component, StatusTracker
from gradio.routes import PredictBody
class Block:
@ -304,7 +306,10 @@ class Blocks(BlockContext):
self.mode = mode
self.is_running = False
self.local_url = None
self.share_url = None
self.width = None
self.height = None
self.ip_address = utils.get_local_ip_address()
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
@ -929,6 +934,60 @@ class Blocks(BlockContext):
return self.server_app, self.local_url, self.share_url
def integrate(
self,
comet_ml: comet_ml.Experiment = None,
wandb: ModuleType("wandb") = None,
mlflow: ModuleType("mlflow") = None,
) -> None:
"""
A catch-all method for integrating with other libraries. This method should be run after launch()
Parameters:
comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
"""
analytics_integration = ""
if comet_ml is not None:
analytics_integration = "CometML"
comet_ml.log_other("Created from", "Gradio")
if self.share_url is not None:
comet_ml.log_text("gradio: " + self.share_url)
comet_ml.end()
else:
comet_ml.log_text("gradio: " + self.local_url)
comet_ml.end()
if wandb is not None:
analytics_integration = "WandB"
if self.share_url is not None:
wandb.log(
{
"Gradio panel": wandb.Html(
'<iframe src="'
+ self.share_url
+ '" width="'
+ str(self.width)
+ '" height="'
+ str(self.height)
+ '" frameBorder="0"></iframe>'
)
}
)
else:
print(
"The WandB integration requires you to "
"`launch(share=True)` first."
)
if mlflow is not None:
analytics_integration = "MLFlow"
if self.share_url is not None:
mlflow.log_param("Gradio Interface Share Link", self.share_url)
else:
mlflow.log_param("Gradio Interface Local Link", self.local_url)
if self.analytics_enabled and analytics_integration:
data = {"integration": analytics_integration}
utils.integration_analytics(data)
def close(self, verbose: bool = True) -> None:
"""
Closes the Interface that was launched and frees the port.

View File

@ -12,6 +12,15 @@ def set_documentation_group(m):
def document(*fns):
"""
Defines the @document decorator which adds classes or functions to the Gradio
documentation at www.gradio.app/docs.
Usage examples:
- Put @document() above a class to document the class and its constructor.
- Put @document(fn1, fn2) above a class to also document the class methods fn1 and fn2.
"""
def inner_doc(cls):
global documentation_group
classes_to_document[documentation_group].append((cls, fns))

View File

@ -5,12 +5,9 @@ including various methods for constructing an interface and then launching it.
from __future__ import annotations
import copy
import csv
import inspect
import json
import os
import random
import re
import warnings
import weakref
@ -45,7 +42,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import transformers
@document("launch", "load", "from_pipeline")
@document("launch", "load", "from_pipeline", "integrate")
class Interface(Blocks):
"""
The Interface class is a high-level abstraction that allows you to create a
@ -722,59 +719,6 @@ class Interface(Blocks):
self.process(raw_input)
print("PASSED")
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
"""
A catch-all method for integrating with other libraries.
Should be run after launch()
Parameters:
comet_ml (Experiment): If a comet_ml Experiment object is provided,
will integrate with the experiment and appear on Comet dashboard
wandb (module): If the wandb module is provided, will integrate
with it and appear on WandB dashboard
mlflow (module): If the mlflow module is provided, will integrate
with the experiment and appear on ML Flow dashboard
"""
analytics_integration = ""
if comet_ml is not None:
analytics_integration = "CometML"
comet_ml.log_other("Created from", "Gradio")
if self.share_url is not None:
comet_ml.log_text("gradio: " + self.share_url)
comet_ml.end()
else:
comet_ml.log_text("gradio: " + self.local_url)
comet_ml.end()
if wandb is not None:
analytics_integration = "WandB"
if self.share_url is not None:
wandb.log(
{
"Gradio panel": wandb.Html(
'<iframe src="'
+ self.share_url
+ '" width="'
+ str(self.width)
+ '" height="'
+ str(self.height)
+ '" frameBorder="0"></iframe>'
)
}
)
else:
print(
"The WandB integration requires you to "
"`launch(share=True)` first."
)
if mlflow is not None:
analytics_integration = "MLFlow"
if self.share_url is not None:
mlflow.log_param("Gradio Interface Share Link", self.share_url)
else:
mlflow.log_param("Gradio Interface Local Link", self.local_url)
if self.analytics_enabled and analytics_integration:
data = {"integration": analytics_integration}
utils.integration_analytics(data)
@document()
class TabbedInterface(Blocks):

View File

@ -1,10 +1,16 @@
import asyncio
import io
import random
import sys
import time
import unittest
import unittest.mock as mock
from contextlib import contextmanager
from unittest.mock import patch
import mlflow
import pytest
import wandb
import gradio as gr
from gradio.routes import PredictBody
@ -14,6 +20,17 @@ from gradio.utils import assert_configs_are_equivalent_besides_ids
pytest_plugins = ("pytest_asyncio",)
@contextmanager
def captured_output():
new_out, new_err = io.StringIO(), io.StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err
class TestBlocks(unittest.TestCase):
maxDiff = None
@ -123,6 +140,62 @@ class TestBlocks(unittest.TestCase):
assert difference >= 0.01
assert result
def test_integration_wandb(self):
with captured_output() as (out, err):
wandb.log = mock.MagicMock()
wandb.Html = mock.MagicMock()
demo = gr.Blocks()
with demo:
gr.Textbox("Hi there!")
demo.integrate(wandb=wandb)
self.assertEqual(
out.getvalue().strip(),
"The WandB integration requires you to `launch(share=True)` first.",
)
demo.share_url = "tmp"
demo.integrate(wandb=wandb)
wandb.log.assert_called_once()
@mock.patch("comet_ml.Experiment")
def test_integration_comet(self, mock_experiment):
experiment = mock_experiment()
experiment.log_text = mock.MagicMock()
experiment.log_other = mock.MagicMock()
demo = gr.Blocks()
with demo:
gr.Textbox("Hi there!")
demo.launch(prevent_thread_lock=True)
demo.integrate(comet_ml=experiment)
experiment.log_text.assert_called_with("gradio: " + demo.local_url)
demo.share_url = "tmp" # used to avoid creating real share links.
demo.integrate(comet_ml=experiment)
experiment.log_text.assert_called_with("gradio: " + demo.share_url)
self.assertEqual(experiment.log_other.call_count, 2)
demo.share_url = None
demo.close()
def test_integration_mlflow(self):
mlflow.log_param = mock.MagicMock()
demo = gr.Blocks()
with demo:
gr.Textbox("Hi there!")
demo.launch(prevent_thread_lock=True)
demo.integrate(mlflow=mlflow)
mlflow.log_param.assert_called_with(
"Gradio Interface Local Link", demo.local_url
)
demo.share_url = "tmp" # used to avoid creating real share links.
demo.integrate(mlflow=mlflow)
mlflow.log_param.assert_called_with(
"Gradio Interface Share Link", demo.share_url
)
demo.share_url = None
demo.close()
if __name__ == "__main__":
unittest.main()