From 30ed40afb75eef3e67398410dc4bd588ed9e1b86 Mon Sep 17 00:00:00 2001 From: AK391 <81195143+AK391@users.noreply.github.com> Date: Thu, 11 Nov 2021 22:43:27 -0500 Subject: [PATCH] tests for integration and fixes bug --- gradio/interface.py | 2 +- test/test_interfaces.py | 53 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/gradio/interface.py b/gradio/interface.py index 11fd49b231..9e7f7e74ba 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -692,7 +692,7 @@ class Interface: mlflow.log_param("Gradio Interface Local Link", self.local_url) if self.analytics_enabled: - if not analytics_integration: + if analytics_integration: data = {'integration': analytics_integration} try: requests.post(analytics_url + diff --git a/test/test_interfaces.py b/test/test_interfaces.py index fc65409ae0..4a6d2192ad 100644 --- a/test/test_interfaces.py +++ b/test/test_interfaces.py @@ -7,6 +7,9 @@ import sys from contextlib import contextmanager import io import threading +from comet_ml import Experiment +import mlflow +import wandb os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -178,6 +181,56 @@ class TestInterface(unittest.TestCase): interface.launch(inline=True, share=True, prevent_thread_lock=True) self.assertEqual(mock_display.call_count, 2) interface.close() + + @mock.patch('comet_ml.Experiment') + def test_integration_comet(self, mock_experiment): + experiment = mock_experiment() + experiment.log_text = mock.MagicMock() + experiment.log_other = mock.MagicMock() + interface = Interface(lambda x: x, "textbox", "label") + interface.launch(prevent_thread_lock=True) + interface.integrate(comet_ml=experiment) + experiment.log_text.assert_called_with('gradio: ' + interface.local_url) + interface.share_url = 'tmp' # used to avoid creating real share links. + interface.integrate(comet_ml=experiment) + experiment.log_text.assert_called_with('gradio: ' + interface.share_url) + self.assertEqual(experiment.log_other.call_count, 2) + interface.share_url = None + interface.close() + + def test_integration_mlflow(self): + mlflow.log_param = mock.MagicMock() + interface = Interface(lambda x: x, "textbox", "label") + interface.launch(prevent_thread_lock=True) + interface.integrate(mlflow=mlflow) + mlflow.log_param.assert_called_with("Gradio Interface Local Link", interface.local_url) + interface.share_url = 'tmp' # used to avoid creating real share links. + interface.integrate(mlflow=mlflow) + mlflow.log_param.assert_called_with("Gradio Interface Share Link", interface.share_url) + interface.share_url = None + interface.close() + def test_integration_wandb(self): + with captured_output() as (out, err): + wandb.log = mock.MagicMock() + wandb.Html = mock.MagicMock() + interface = Interface(lambda x: x, "textbox", "label") + interface.integrate(wandb=wandb) + self.assertEqual(out.getvalue().strip(), "The WandB integration requires you to `launch(share=True)` first.") + interface.share_url = 'tmp' + interface.integrate(wandb=wandb) + wandb.log.assert_called_once() + + @mock.patch('requests.post') + def test_integration_analytics(self, mock_post): + mlflow.log_param = mock.MagicMock() + interface = Interface(lambda x: x, "textbox", "label") + interface.analytics_enabled = True + interface.integrate(mlflow=mlflow) + mock_post.assert_called_once() + + + + if __name__ == '__main__': unittest.main() \ No newline at end of file