diff --git a/gradio/utils.py b/gradio/utils.py index bf4d988e4d..65c2988cc9 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1,3 +1,5 @@ +import json.decoder + import requests import pkg_resources from distutils.version import StrictVersion @@ -15,8 +17,15 @@ def version_check(): "is available, please upgrade.".format( current_pkg_version, latest_pkg_version)) print('--------') - except: # TODO(abidlabs): don't catch all exceptions - pass + + except pkg_resources.DistributionNotFound: + raise RuntimeError("gradio is not setup or installed properly. Unable to get version info.") + except json.decoder.JSONDecodeError: + raise RuntimeWarning("Unable to parse version details from package URL.") + except KeyError: + raise RuntimeWarning("Package URL does not contain version info.") + except ConnectionError: + raise RuntimeWarning("Unable to connect with package URL to collect version info.") def error_analytics(type): diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000000..bc1187cd87 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,46 @@ +from gradio.utils import * +import unittest +import pkg_resources +import unittest.mock as mock + + +class TestUtils(unittest.TestCase): + @mock.patch("pkg_resources.require") + def test_should_fail_with_distribution_not_found(self, mock_require): + + mock_require.side_effect = pkg_resources.DistributionNotFound() + + with self.assertRaises(RuntimeError) as e: + version_check() + self.assertEqual(str(e.exception), "gradio is not setup or installed properly. Unable to get version info.") + + @mock.patch("requests.get") + def test_should_warn_with_unable_to_parse(self, mock_get): + + mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0) + + with self.assertRaises(RuntimeWarning) as e: + version_check() + self.assertEqual(str(e.exception), "Unable to parse version details from package URL.") + + @mock.patch("requests.get") + def test_should_warn_with_connection_error(self, mock_get): + + mock_get.side_effect = ConnectionError() + + with self.assertRaises(RuntimeWarning) as e: + version_check() + self.assertEqual(str(e.exception), "Unable to connect with package URL to collect version info.") + + @mock.patch("requests.Response.json") + def test_should_warn_url_not_having_version(self, mock_json): + + mock_json.return_value = {"foo": "bar"} + + with self.assertRaises(RuntimeWarning) as e: + version_check() + self.assertEqual(str(e.exception), "Package URL does not contain version info.") + + +if __name__ == '__main__': + unittest.main()