mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
integrate()
method moved to Blocks
(#1776)
* integrating blocks * formatting * added tests * tests * formatting * added integrate() to docs * typing * typing
This commit is contained in:
parent
a44f8f7780
commit
4149d00822
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import getpass
|
||||
import inspect
|
||||
@ -9,6 +8,7 @@ import random
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import anyio
|
||||
@ -23,10 +23,12 @@ from gradio.utils import component_or_layout_class, delete_none
|
||||
set_documentation_group("blocks")
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import comet_ml
|
||||
import mlflow
|
||||
import wandb
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component, StatusTracker
|
||||
from gradio.routes import PredictBody
|
||||
|
||||
|
||||
class Block:
|
||||
@ -304,7 +306,10 @@ class Blocks(BlockContext):
|
||||
self.mode = mode
|
||||
|
||||
self.is_running = False
|
||||
self.local_url = None
|
||||
self.share_url = None
|
||||
self.width = None
|
||||
self.height = None
|
||||
|
||||
self.ip_address = utils.get_local_ip_address()
|
||||
self.is_space = True if os.getenv("SYSTEM") == "spaces" else False
|
||||
@ -929,6 +934,60 @@ class Blocks(BlockContext):
|
||||
|
||||
return self.server_app, self.local_url, self.share_url
|
||||
|
||||
def integrate(
|
||||
self,
|
||||
comet_ml: comet_ml.Experiment = None,
|
||||
wandb: ModuleType("wandb") = None,
|
||||
mlflow: ModuleType("mlflow") = None,
|
||||
) -> None:
|
||||
"""
|
||||
A catch-all method for integrating with other libraries. This method should be run after launch()
|
||||
Parameters:
|
||||
comet_ml: If a comet_ml Experiment object is provided, will integrate with the experiment and appear on Comet dashboard
|
||||
wandb: If the wandb module is provided, will integrate with it and appear on WandB dashboard
|
||||
mlflow: If the mlflow module is provided, will integrate with the experiment and appear on ML Flow dashboard
|
||||
"""
|
||||
analytics_integration = ""
|
||||
if comet_ml is not None:
|
||||
analytics_integration = "CometML"
|
||||
comet_ml.log_other("Created from", "Gradio")
|
||||
if self.share_url is not None:
|
||||
comet_ml.log_text("gradio: " + self.share_url)
|
||||
comet_ml.end()
|
||||
else:
|
||||
comet_ml.log_text("gradio: " + self.local_url)
|
||||
comet_ml.end()
|
||||
if wandb is not None:
|
||||
analytics_integration = "WandB"
|
||||
if self.share_url is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"Gradio panel": wandb.Html(
|
||||
'<iframe src="'
|
||||
+ self.share_url
|
||||
+ '" width="'
|
||||
+ str(self.width)
|
||||
+ '" height="'
|
||||
+ str(self.height)
|
||||
+ '" frameBorder="0"></iframe>'
|
||||
)
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"The WandB integration requires you to "
|
||||
"`launch(share=True)` first."
|
||||
)
|
||||
if mlflow is not None:
|
||||
analytics_integration = "MLFlow"
|
||||
if self.share_url is not None:
|
||||
mlflow.log_param("Gradio Interface Share Link", self.share_url)
|
||||
else:
|
||||
mlflow.log_param("Gradio Interface Local Link", self.local_url)
|
||||
if self.analytics_enabled and analytics_integration:
|
||||
data = {"integration": analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
|
||||
def close(self, verbose: bool = True) -> None:
|
||||
"""
|
||||
Closes the Interface that was launched and frees the port.
|
||||
|
@ -12,6 +12,15 @@ def set_documentation_group(m):
|
||||
|
||||
|
||||
def document(*fns):
|
||||
"""
|
||||
Defines the @document decorator which adds classes or functions to the Gradio
|
||||
documentation at www.gradio.app/docs.
|
||||
|
||||
Usage examples:
|
||||
- Put @document() above a class to document the class and its constructor.
|
||||
- Put @document(fn1, fn2) above a class to also document the class methods fn1 and fn2.
|
||||
"""
|
||||
|
||||
def inner_doc(cls):
|
||||
global documentation_group
|
||||
classes_to_document[documentation_group].append((cls, fns))
|
||||
|
@ -5,12 +5,9 @@ including various methods for constructing an interface and then launching it.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import csv
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import warnings
|
||||
import weakref
|
||||
@ -45,7 +42,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import transformers
|
||||
|
||||
|
||||
@document("launch", "load", "from_pipeline")
|
||||
@document("launch", "load", "from_pipeline", "integrate")
|
||||
class Interface(Blocks):
|
||||
"""
|
||||
The Interface class is a high-level abstraction that allows you to create a
|
||||
@ -722,59 +719,6 @@ class Interface(Blocks):
|
||||
self.process(raw_input)
|
||||
print("PASSED")
|
||||
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
|
||||
"""
|
||||
A catch-all method for integrating with other libraries.
|
||||
Should be run after launch()
|
||||
Parameters:
|
||||
comet_ml (Experiment): If a comet_ml Experiment object is provided,
|
||||
will integrate with the experiment and appear on Comet dashboard
|
||||
wandb (module): If the wandb module is provided, will integrate
|
||||
with it and appear on WandB dashboard
|
||||
mlflow (module): If the mlflow module is provided, will integrate
|
||||
with the experiment and appear on ML Flow dashboard
|
||||
"""
|
||||
analytics_integration = ""
|
||||
if comet_ml is not None:
|
||||
analytics_integration = "CometML"
|
||||
comet_ml.log_other("Created from", "Gradio")
|
||||
if self.share_url is not None:
|
||||
comet_ml.log_text("gradio: " + self.share_url)
|
||||
comet_ml.end()
|
||||
else:
|
||||
comet_ml.log_text("gradio: " + self.local_url)
|
||||
comet_ml.end()
|
||||
if wandb is not None:
|
||||
analytics_integration = "WandB"
|
||||
if self.share_url is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"Gradio panel": wandb.Html(
|
||||
'<iframe src="'
|
||||
+ self.share_url
|
||||
+ '" width="'
|
||||
+ str(self.width)
|
||||
+ '" height="'
|
||||
+ str(self.height)
|
||||
+ '" frameBorder="0"></iframe>'
|
||||
)
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"The WandB integration requires you to "
|
||||
"`launch(share=True)` first."
|
||||
)
|
||||
if mlflow is not None:
|
||||
analytics_integration = "MLFlow"
|
||||
if self.share_url is not None:
|
||||
mlflow.log_param("Gradio Interface Share Link", self.share_url)
|
||||
else:
|
||||
mlflow.log_param("Gradio Interface Local Link", self.local_url)
|
||||
if self.analytics_enabled and analytics_integration:
|
||||
data = {"integration": analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
|
||||
|
||||
@document()
|
||||
class TabbedInterface(Blocks):
|
||||
|
@ -1,10 +1,16 @@
|
||||
import asyncio
|
||||
import io
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import mlflow
|
||||
import pytest
|
||||
import wandb
|
||||
|
||||
import gradio as gr
|
||||
from gradio.routes import PredictBody
|
||||
@ -14,6 +20,17 @@ from gradio.utils import assert_configs_are_equivalent_besides_ids
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def captured_output():
|
||||
new_out, new_err = io.StringIO(), io.StringIO()
|
||||
old_out, old_err = sys.stdout, sys.stderr
|
||||
try:
|
||||
sys.stdout, sys.stderr = new_out, new_err
|
||||
yield sys.stdout, sys.stderr
|
||||
finally:
|
||||
sys.stdout, sys.stderr = old_out, old_err
|
||||
|
||||
|
||||
class TestBlocks(unittest.TestCase):
|
||||
maxDiff = None
|
||||
|
||||
@ -123,6 +140,62 @@ class TestBlocks(unittest.TestCase):
|
||||
assert difference >= 0.01
|
||||
assert result
|
||||
|
||||
def test_integration_wandb(self):
|
||||
with captured_output() as (out, err):
|
||||
wandb.log = mock.MagicMock()
|
||||
wandb.Html = mock.MagicMock()
|
||||
demo = gr.Blocks()
|
||||
with demo:
|
||||
gr.Textbox("Hi there!")
|
||||
demo.integrate(wandb=wandb)
|
||||
|
||||
self.assertEqual(
|
||||
out.getvalue().strip(),
|
||||
"The WandB integration requires you to `launch(share=True)` first.",
|
||||
)
|
||||
demo.share_url = "tmp"
|
||||
demo.integrate(wandb=wandb)
|
||||
wandb.log.assert_called_once()
|
||||
|
||||
@mock.patch("comet_ml.Experiment")
|
||||
def test_integration_comet(self, mock_experiment):
|
||||
experiment = mock_experiment()
|
||||
experiment.log_text = mock.MagicMock()
|
||||
experiment.log_other = mock.MagicMock()
|
||||
|
||||
demo = gr.Blocks()
|
||||
with demo:
|
||||
gr.Textbox("Hi there!")
|
||||
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
demo.integrate(comet_ml=experiment)
|
||||
experiment.log_text.assert_called_with("gradio: " + demo.local_url)
|
||||
demo.share_url = "tmp" # used to avoid creating real share links.
|
||||
demo.integrate(comet_ml=experiment)
|
||||
experiment.log_text.assert_called_with("gradio: " + demo.share_url)
|
||||
self.assertEqual(experiment.log_other.call_count, 2)
|
||||
demo.share_url = None
|
||||
demo.close()
|
||||
|
||||
def test_integration_mlflow(self):
|
||||
mlflow.log_param = mock.MagicMock()
|
||||
demo = gr.Blocks()
|
||||
with demo:
|
||||
gr.Textbox("Hi there!")
|
||||
|
||||
demo.launch(prevent_thread_lock=True)
|
||||
demo.integrate(mlflow=mlflow)
|
||||
mlflow.log_param.assert_called_with(
|
||||
"Gradio Interface Local Link", demo.local_url
|
||||
)
|
||||
demo.share_url = "tmp" # used to avoid creating real share links.
|
||||
demo.integrate(mlflow=mlflow)
|
||||
mlflow.log_param.assert_called_with(
|
||||
"Gradio Interface Share Link", demo.share_url
|
||||
)
|
||||
demo.share_url = None
|
||||
demo.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user