mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Do not do any analytics requests if analytics are disabled (#4194)
* Move analytics-related bits to gradio.analytics * Do not do any analytics requests if analytics are disabled * Remove unused log_feature_analytics * removed redundant analytics, rewrote * renamed * save * fixed test' --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
96b38fda07
commit
8b72e9e127
@ -8,6 +8,8 @@
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- Fix "TypeError: issubclass() arg 1 must be a class" When use Optional[Types] by [@lingfengchencn](https://github.com/lingfengchencn) in [PR 4200](https://github.com/gradio-app/gradio/pull/4200).
|
||||
- Gradio will no longer send any analytics if analytics are disabled with the GRADIO_ANALYTICS_ENABLED environment variable. By [@akx](https://github.com/akx) in [PR 4194](https://github.com/gradio-app/gradio/pull/4194)
|
||||
- The deprecation warnings for kwargs now show the actual stack level for the invocation, by [@akx](https://github.com/akx) in [PR 4203](https://github.com/gradio-app/gradio/pull/4203).
|
||||
- Fix "TypeError: issubclass() arg 1 must be a class" When use Optional[Types] by [@lingfengchencn](https://github.com/lingfengchencn) in [PR 4200](https://github.com/gradio-app/gradio/pull/4200).
|
||||
|
||||
|
187
gradio/analytics.py
Normal file
187
gradio/analytics.py
Normal file
@ -0,0 +1,187 @@
|
||||
""" Functions related to analytics and telemetry. """
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import threading
|
||||
import warnings
|
||||
from distutils.version import StrictVersion
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
import gradio
|
||||
from gradio.context import Context
|
||||
from gradio.utils import GRADIO_VERSION
|
||||
|
||||
ANALYTICS_URL = "https://api.gradio.app/"
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
|
||||
|
||||
def analytics_enabled() -> bool:
|
||||
"""
|
||||
Returns: True if analytics are enabled, False otherwise.
|
||||
"""
|
||||
return os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
||||
|
||||
|
||||
def _do_analytics_request(url: str, data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(url, data=data, timeout=5)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def version_check():
|
||||
if not analytics_enabled():
|
||||
return
|
||||
try:
|
||||
version_data = pkgutil.get_data(__name__, "version.txt")
|
||||
if not version_data:
|
||||
raise FileNotFoundError
|
||||
current_pkg_version = version_data.decode("ascii").strip()
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL, timeout=3).json()[
|
||||
"version"
|
||||
]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(
|
||||
f"IMPORTANT: You are using gradio version {current_pkg_version}, "
|
||||
f"however version {latest_pkg_version} is available, please upgrade."
|
||||
)
|
||||
print("--------")
|
||||
except json.decoder.JSONDecodeError:
|
||||
warnings.warn("unable to parse version details from package URL.")
|
||||
except KeyError:
|
||||
warnings.warn("package URL does not contain version info.")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_local_ip_address() -> str:
|
||||
"""
|
||||
Gets the public IP address or returns the string "No internet connection" if unable
|
||||
to obtain it or the string "Analytics disabled" if a user has disabled analytics.
|
||||
Does not make a new request if the IP address has already been obtained in the
|
||||
same Python session.
|
||||
"""
|
||||
if not analytics_enabled():
|
||||
return "Analytics disabled"
|
||||
|
||||
if Context.ip_address is None:
|
||||
try:
|
||||
ip_address = requests.get(
|
||||
"https://checkip.amazonaws.com/", timeout=3
|
||||
).text.strip()
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
ip_address = "No internet connection"
|
||||
Context.ip_address = ip_address
|
||||
else:
|
||||
ip_address = Context.ip_address
|
||||
return ip_address
|
||||
|
||||
|
||||
def initiated_analytics(data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-initiated-analytics/",
|
||||
"data": {**data, "ip_address": get_local_ip_address()},
|
||||
},
|
||||
).start()
|
||||
|
||||
|
||||
def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
from gradio.blocks import BlockContext
|
||||
|
||||
for x in list(blocks.blocks.values()):
|
||||
blocks_telemetry.append(x.get_block_name()) if isinstance(
|
||||
x, BlockContext
|
||||
) else blocks_telemetry.append(str(x))
|
||||
|
||||
for x in blocks.dependencies:
|
||||
targets_telemetry = targets_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["targets"]
|
||||
]
|
||||
inputs_telemetry = inputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["inputs"]
|
||||
]
|
||||
outputs_telemetry = outputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["outputs"]
|
||||
]
|
||||
additional_data = {
|
||||
"version": GRADIO_VERSION,
|
||||
"is_kaggle": blocks.is_kaggle,
|
||||
"is_sagemaker": blocks.is_sagemaker,
|
||||
"using_auth": blocks.auth is not None,
|
||||
"dev_mode": blocks.dev_mode,
|
||||
"show_api": blocks.show_api,
|
||||
"show_error": blocks.show_error,
|
||||
"title": blocks.title,
|
||||
"inputs": blocks.input_components
|
||||
if blocks.mode == "interface"
|
||||
else inputs_telemetry,
|
||||
"outputs": blocks.output_components
|
||||
if blocks.mode == "interface"
|
||||
else outputs_telemetry,
|
||||
"targets": targets_telemetry,
|
||||
"blocks": blocks_telemetry,
|
||||
"events": [str(x["trigger"]) for x in blocks.dependencies],
|
||||
}
|
||||
|
||||
data.update(additional_data)
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-launched-telemetry/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
|
||||
|
||||
def integration_analytics(data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-integration-analytics/",
|
||||
"data": {**data, "ip_address": get_local_ip_address()},
|
||||
},
|
||||
).start()
|
||||
|
||||
|
||||
def error_analytics(message: str) -> None:
|
||||
"""
|
||||
Send error analytics if there is network
|
||||
Parameters:
|
||||
message: Details about error
|
||||
"""
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
data = {"ip_address": get_local_ip_address(), "error": message}
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-error-analytics/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
@ -24,6 +24,7 @@ from packaging import version
|
||||
from typing_extensions import Literal
|
||||
|
||||
from gradio import (
|
||||
analytics,
|
||||
components,
|
||||
external,
|
||||
networking,
|
||||
@ -686,7 +687,7 @@ class Blocks(BlockContext):
|
||||
self.analytics_enabled = (
|
||||
analytics_enabled
|
||||
if analytics_enabled is not None
|
||||
else os.getenv("GRADIO_ANALYTICS_ENABLED", "True") == "True"
|
||||
else analytics.analytics_enabled()
|
||||
)
|
||||
if not self.analytics_enabled:
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
||||
@ -737,7 +738,7 @@ class Blocks(BlockContext):
|
||||
"is_custom_theme": is_custom_theme,
|
||||
"version": GRADIO_VERSION,
|
||||
}
|
||||
utils.initiated_analytics(data)
|
||||
analytics.initiated_analytics(data)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
@ -1835,7 +1836,7 @@ Received outputs:
|
||||
print(strings.en["SHARE_LINK_MESSAGE"])
|
||||
except (RuntimeError, requests.exceptions.ConnectionError):
|
||||
if self.analytics_enabled:
|
||||
utils.error_analytics("Not able to set up tunnel")
|
||||
analytics.error_analytics("Not able to set up tunnel")
|
||||
self.share_url = None
|
||||
self.share = False
|
||||
print(strings.en["COULD_NOT_GET_SHARE_LINK"])
|
||||
@ -1925,8 +1926,7 @@ Received outputs:
|
||||
"is_spaces": self.is_space,
|
||||
"mode": self.mode,
|
||||
}
|
||||
utils.launch_analytics(data)
|
||||
utils.launched_telemetry(self, data)
|
||||
analytics.launched_analytics(self, data)
|
||||
|
||||
utils.show_tip(self)
|
||||
|
||||
@ -1995,7 +1995,7 @@ Received outputs:
|
||||
mlflow.log_param("Gradio Interface Local Link", self.local_url)
|
||||
if self.analytics_enabled and analytics_integration:
|
||||
data = {"integration": analytics_integration}
|
||||
utils.integration_analytics(data)
|
||||
analytics.integration_analytics(data)
|
||||
|
||||
def close(self, verbose: bool = True) -> None:
|
||||
"""
|
||||
|
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio import Examples, external, interpretation, utils
|
||||
from gradio import Examples, analytics, external, interpretation, utils
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import (
|
||||
Button,
|
||||
@ -362,7 +362,7 @@ class Interface(Blocks):
|
||||
self.local_url = None
|
||||
|
||||
self.favicon_path = None
|
||||
utils.version_check()
|
||||
analytics.version_check()
|
||||
Interface.instances.add(self)
|
||||
|
||||
param_names = inspect.getfullargspec(self.fn)[0]
|
||||
|
175
gradio/utils.py
175
gradio/utils.py
@ -13,12 +13,10 @@ import pkgutil
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import StrictVersion
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from numbers import Number
|
||||
@ -32,7 +30,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
import httpx
|
||||
import matplotlib
|
||||
@ -54,8 +51,6 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio.blocks import Block, BlockContext
|
||||
from gradio.components import Component
|
||||
|
||||
analytics_url = "https://api.gradio.app/"
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
|
||||
GRADIO_VERSION = (
|
||||
(pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
|
||||
@ -64,176 +59,6 @@ GRADIO_VERSION = (
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def version_check():
|
||||
try:
|
||||
version_data = pkgutil.get_data(__name__, "version.txt")
|
||||
if not version_data:
|
||||
raise FileNotFoundError
|
||||
current_pkg_version = version_data.decode("ascii").strip()
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL, timeout=3).json()[
|
||||
"version"
|
||||
]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(
|
||||
f"IMPORTANT: You are using gradio version {current_pkg_version}, "
|
||||
f"however version {latest_pkg_version} is available, please upgrade."
|
||||
)
|
||||
print("--------")
|
||||
except json.decoder.JSONDecodeError:
|
||||
warnings.warn("unable to parse version details from package URL.")
|
||||
except KeyError:
|
||||
warnings.warn("package URL does not contain version info.")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_local_ip_address() -> str:
|
||||
"""Gets the public IP address or returns the string "No internet connection" if unable to obtain it. Does not make a new request if the IP address has already been obtained."""
|
||||
if Context.ip_address is None:
|
||||
try:
|
||||
ip_address = requests.get(
|
||||
"https://checkip.amazonaws.com/", timeout=3
|
||||
).text.strip()
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
ip_address = "No internet connection"
|
||||
Context.ip_address = ip_address
|
||||
else:
|
||||
ip_address = Context.ip_address
|
||||
return ip_address
|
||||
|
||||
|
||||
def initiated_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def initiated_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-initiated-analytics/", data=data, timeout=5
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
threading.Thread(target=initiated_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def launch_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def launch_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-launched-analytics/", data=data, timeout=5
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
threading.Thread(target=launch_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def launched_telemetry(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
from gradio.blocks import BlockContext
|
||||
|
||||
for x in list(blocks.blocks.values()):
|
||||
blocks_telemetry.append(x.get_block_name()) if isinstance(
|
||||
x, BlockContext
|
||||
) else blocks_telemetry.append(str(x))
|
||||
|
||||
for x in blocks.dependencies:
|
||||
targets_telemetry = targets_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["targets"]
|
||||
]
|
||||
inputs_telemetry = inputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["inputs"]
|
||||
]
|
||||
outputs_telemetry = outputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["outputs"]
|
||||
]
|
||||
additional_data = {
|
||||
"version": GRADIO_VERSION,
|
||||
"is_kaggle": blocks.is_kaggle,
|
||||
"is_sagemaker": blocks.is_sagemaker,
|
||||
"using_auth": blocks.auth is not None,
|
||||
"dev_mode": blocks.dev_mode,
|
||||
"show_api": blocks.show_api,
|
||||
"show_error": blocks.show_error,
|
||||
"title": blocks.title,
|
||||
"inputs": blocks.input_components
|
||||
if blocks.mode == "interface"
|
||||
else inputs_telemetry,
|
||||
"outputs": blocks.output_components
|
||||
if blocks.mode == "interface"
|
||||
else outputs_telemetry,
|
||||
"targets": targets_telemetry,
|
||||
"blocks": blocks_telemetry,
|
||||
"events": [str(x["trigger"]) for x in blocks.dependencies],
|
||||
}
|
||||
|
||||
data.update(additional_data)
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def launched_telemtry_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-launched-telemetry/", data=data, timeout=5
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
threading.Thread(target=launched_telemtry_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def integration_analytics(data: dict[str, Any]) -> None:
|
||||
data.update({"ip_address": get_local_ip_address()})
|
||||
|
||||
def integration_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-integration-analytics/", data=data, timeout=5
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
threading.Thread(target=integration_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
def error_analytics(message: str) -> None:
|
||||
"""
|
||||
Send error analytics if there is network
|
||||
Parameters:
|
||||
message: Details about error
|
||||
"""
|
||||
data = {"ip_address": get_local_ip_address(), "error": message}
|
||||
|
||||
def error_analytics_thread(data: dict[str, Any]) -> None:
|
||||
try:
|
||||
requests.post(
|
||||
f"{analytics_url}gradio-error-analytics/", data=data, timeout=5
|
||||
)
|
||||
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
threading.Thread(target=error_analytics_thread, args=(data,)).start()
|
||||
|
||||
|
||||
async def log_feature_analytics(feature: str) -> None:
|
||||
data = {"ip_address": get_local_ip_address(), "feature": feature}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.post(
|
||||
f"{analytics_url}gradio-feature-analytics/", data=data
|
||||
):
|
||||
pass
|
||||
except (aiohttp.ClientError):
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def colab_check() -> bool:
|
||||
"""
|
||||
Check if interface is launching from Google Colab
|
||||
|
76
test/test_analytics.py
Normal file
76
test/test_analytics.py
Normal file
@ -0,0 +1,76 @@
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from unittest import mock as mock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from gradio import analytics
|
||||
from gradio.context import Context
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestAnalytics:
|
||||
@mock.patch("requests.get")
|
||||
def test_should_warn_with_unable_to_parse(self, mock_get, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
analytics.version_check()
|
||||
assert (
|
||||
str(w[-1].message)
|
||||
== "unable to parse version details from package URL."
|
||||
)
|
||||
|
||||
@mock.patch("requests.Response.json")
|
||||
def test_should_warn_url_not_having_version(self, mock_json, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
mock_json.return_value = {"foo": "bar"}
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
analytics.version_check()
|
||||
assert str(w[-1].message) == "package URL does not contain version info."
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(
|
||||
self, mock_post, monkeypatch
|
||||
):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
analytics.error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_successful(self, mock_post, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
analytics.error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
|
||||
class TestIPAddress:
|
||||
@pytest.mark.flaky
|
||||
def test_get_ip(self):
|
||||
Context.ip_address = None
|
||||
ip = analytics.get_local_ip_address()
|
||||
if ip == "No internet connection" or ip == "Analytics disabled":
|
||||
return
|
||||
ipaddress.ip_address(ip)
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_get_ip_without_internet(self, mock_get, monkeypatch):
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
Context.ip_address = None
|
||||
ip = analytics.get_local_ip_address()
|
||||
assert ip == "No internet connection"
|
||||
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "False")
|
||||
Context.ip_address = None
|
||||
ip = analytics.get_local_ip_address()
|
||||
assert ip == "Analytics disabled"
|
@ -198,7 +198,8 @@ class TestBlocksMethods:
|
||||
assert result
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_initiated_analytics(self, mock_post):
|
||||
def test_initiated_analytics(self, mock_post, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
with gr.Blocks(analytics_enabled=True):
|
||||
pass
|
||||
mock_post.assert_called_once()
|
||||
|
@ -1,6 +1,4 @@
|
||||
import copy
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest.mock as mock
|
||||
@ -16,7 +14,6 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from gradio import EventData, Request
|
||||
from gradio.context import Context
|
||||
from gradio.test_data.blocks_configs import (
|
||||
XRAY_CONFIG,
|
||||
XRAY_CONFIG_DIFF_IDS,
|
||||
@ -30,65 +27,23 @@ from gradio.utils import (
|
||||
check_function_inputs_match,
|
||||
colab_check,
|
||||
delete_none,
|
||||
error_analytics,
|
||||
format_ner_list,
|
||||
get_local_ip_address,
|
||||
get_type_hints,
|
||||
ipython_check,
|
||||
is_special_typed_parameter,
|
||||
kaggle_check,
|
||||
launch_analytics,
|
||||
readme_to_html,
|
||||
sagemaker_check,
|
||||
sanitize_list_for_csv,
|
||||
sanitize_value_for_csv,
|
||||
tex2svg,
|
||||
validate_url,
|
||||
version_check,
|
||||
)
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestUtils:
|
||||
@mock.patch("requests.get")
|
||||
def test_should_warn_with_unable_to_parse(self, mock_get):
|
||||
mock_get.side_effect = json.decoder.JSONDecodeError("Expecting value", "", 0)
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
assert (
|
||||
str(w[-1].message)
|
||||
== "unable to parse version details from package URL."
|
||||
)
|
||||
|
||||
@mock.patch("requests.Response.json")
|
||||
def test_should_warn_url_not_having_version(self, mock_json):
|
||||
mock_json.return_value = {"foo": "bar"}
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
version_check()
|
||||
assert str(w[-1].message) == "package URL does not contain version info."
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_error_analytics_successful(self, mock_post):
|
||||
error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
def test_launch_analytics_doesnt_crash_on_connection_error(self, mock_post):
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
launch_analytics(data={})
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("IPython.get_ipython")
|
||||
def test_colab_check_no_ipython(self, mock_get_ipython):
|
||||
mock_get_ipython.return_value = None
|
||||
@ -154,23 +109,6 @@ class TestUtils:
|
||||
assert not kaggle_check()
|
||||
|
||||
|
||||
class TestIPAddress:
|
||||
@pytest.mark.flaky
|
||||
def test_get_ip(self):
|
||||
Context.ip_address = None
|
||||
ip = get_local_ip_address()
|
||||
if ip == "No internet connection":
|
||||
return
|
||||
ipaddress.ip_address(ip)
|
||||
|
||||
@mock.patch("requests.get")
|
||||
def test_get_ip_without_internet(self, mock_get):
|
||||
Context.ip_address = None
|
||||
mock_get.side_effect = requests.ConnectionError()
|
||||
ip = get_local_ip_address()
|
||||
assert ip == "No internet connection"
|
||||
|
||||
|
||||
class TestAssertConfigsEquivalent:
|
||||
def test_same_configs(self):
|
||||
assert assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, XRAY_CONFIG)
|
||||
@ -402,7 +340,6 @@ async def test_get(respx_mock):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post(respx_mock):
|
||||
|
||||
payload = {"name": "morpheus", "job": "leader"}
|
||||
respx_mock.post(MOCK_REQUEST_URL).mock(make_mock_response(payload))
|
||||
|
||||
@ -419,7 +356,6 @@ async def test_post(respx_mock):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_with_model(respx_mock):
|
||||
|
||||
response = make_mock_response(
|
||||
{
|
||||
"name": "morpheus",
|
||||
@ -486,7 +422,6 @@ async def test_exception_type(validate_response_data, respx_mock):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_with_function(respx_mock):
|
||||
|
||||
respx_mock.post(MOCK_REQUEST_URL).mock(
|
||||
make_mock_response({"name": "morpheus", "id": 1})
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user