mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Lite: Fix the analytics module to use asyncio to work in the Wasm env (#5045)
* Fix the analytics module to use asyncio to work in the Wasm env
* add changeset
* add changeset
* Add changeset
* Revert "Add changeset"
This reverts commit 052f2bd737
.
* Add the `is_wasm` field to the analytics telemetry
* Restore the initialization code in Blocks.launch() in the Wasm env
* Remove a call of analytics.version_check() in Blocks.launch() in the Wasm mode
* Add a `type: ignore` directive to the line of importing the `pyodide` module
* Fix a test case
* Refactor _do_wasm_analytics_request() and add a unit test for it
* Get the IP address for analytics in the Wasm mode
---------
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
13e4783535
commit
3b9494f5c5
5
.changeset/hungry-bobcats-own.md
Normal file
5
.changeset/hungry-bobcats-own.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Lite: Fix the analytics module to use asyncio to work in the Wasm env
|
@ -1,10 +1,12 @@
|
||||
""" Functions related to analytics and telemetry. """
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import pkgutil
|
||||
import threading
|
||||
import urllib.parse
|
||||
import warnings
|
||||
from distutils.version import StrictVersion
|
||||
from typing import Any
|
||||
@ -12,9 +14,21 @@ from typing import Any
|
||||
import requests
|
||||
|
||||
import gradio
|
||||
from gradio import wasm_utils
|
||||
from gradio.context import Context
|
||||
from gradio.utils import GRADIO_VERSION
|
||||
|
||||
# For testability, we import the pyfetch function into this module scope and define a fallback coroutine object to be patched in tests.
|
||||
try:
|
||||
from pyodide.http import pyfetch as pyodide_pyfetch # type: ignore
|
||||
except ImportError:
|
||||
|
||||
async def pyodide_pyfetch(*args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"pyodide.http.pyfetch is not available in this environment."
|
||||
)
|
||||
|
||||
|
||||
ANALYTICS_URL = "https://api.gradio.app/"
|
||||
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
|
||||
|
||||
@ -27,6 +41,24 @@ def analytics_enabled() -> bool:
|
||||
|
||||
|
||||
def _do_analytics_request(url: str, data: dict[str, Any]) -> None:
|
||||
if wasm_utils.IS_WASM:
|
||||
asyncio.ensure_future(
|
||||
_do_wasm_analytics_request(
|
||||
url=url,
|
||||
data=data,
|
||||
)
|
||||
)
|
||||
else:
|
||||
threading.Thread(
|
||||
target=_do_normal_analytics_request,
|
||||
kwargs={
|
||||
"url": url,
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
|
||||
|
||||
def _do_normal_analytics_request(url: str, data: dict[str, Any]) -> None:
|
||||
data["ip_address"] = get_local_ip_address()
|
||||
try:
|
||||
requests.post(url, data=data, timeout=5)
|
||||
@ -34,6 +66,25 @@ def _do_analytics_request(url: str, data: dict[str, Any]) -> None:
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
async def _do_wasm_analytics_request(url: str, data: dict[str, Any]) -> None:
|
||||
data["ip_address"] = await get_local_ip_address_wasm()
|
||||
|
||||
# We use urllib.parse.urlencode to encode the data as a form.
|
||||
# Ref: https://docs.python.org/3/library/urllib.request.html#urllib-examples
|
||||
body = urllib.parse.urlencode(data).encode("ascii")
|
||||
headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
pyodide_pyfetch(url, method="POST", headers=headers, body=body),
|
||||
timeout=5,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass # do not push analytics if no network
|
||||
|
||||
|
||||
def version_check():
|
||||
try:
|
||||
version_data = pkgutil.get_data(__name__, "version.txt")
|
||||
@ -80,17 +131,38 @@ def get_local_ip_address() -> str:
|
||||
return ip_address
|
||||
|
||||
|
||||
async def get_local_ip_address_wasm() -> str:
|
||||
"""The Wasm-compatible version of get_local_ip_address()."""
|
||||
if not analytics_enabled():
|
||||
return "Analytics disabled"
|
||||
|
||||
if Context.ip_address is None:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
pyodide_pyfetch(
|
||||
# The API used by the normal version (`get_local_ip_address()`), `https://checkip.amazonaws.com/``, blocks CORS requests, so here we use a different API.
|
||||
"https://api.ipify.org"
|
||||
),
|
||||
timeout=5,
|
||||
)
|
||||
response_text: str = await response.string() # type: ignore
|
||||
ip_address = response_text.strip()
|
||||
except (asyncio.TimeoutError, OSError):
|
||||
ip_address = "No internet connection"
|
||||
Context.ip_address = ip_address
|
||||
else:
|
||||
ip_address = Context.ip_address
|
||||
return ip_address
|
||||
|
||||
|
||||
def initiated_analytics(data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-initiated-analytics/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
_do_analytics_request(
|
||||
url=f"{ANALYTICS_URL}gradio-initiated-analytics/",
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
@ -142,30 +214,22 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
"targets": targets_telemetry,
|
||||
"blocks": blocks_telemetry,
|
||||
"events": [str(x["trigger"]) for x in blocks.dependencies],
|
||||
"is_wasm": wasm_utils.IS_WASM,
|
||||
}
|
||||
|
||||
data.update(additional_data)
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-launched-telemetry/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
_do_analytics_request(url=f"{ANALYTICS_URL}gradio-launched-telemetry/", data=data)
|
||||
|
||||
|
||||
def integration_analytics(data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-integration-analytics/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
_do_analytics_request(
|
||||
url=f"{ANALYTICS_URL}gradio-integration-analytics/",
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def error_analytics(message: str) -> None:
|
||||
@ -179,10 +243,7 @@ def error_analytics(message: str) -> None:
|
||||
|
||||
data = {"error": message}
|
||||
|
||||
threading.Thread(
|
||||
target=_do_analytics_request,
|
||||
kwargs={
|
||||
"url": f"{ANALYTICS_URL}gradio-error-analytics/",
|
||||
"data": data,
|
||||
},
|
||||
).start()
|
||||
_do_analytics_request(
|
||||
url=f"{ANALYTICS_URL}gradio-error-analytics/",
|
||||
data=data,
|
||||
)
|
||||
|
@ -721,8 +721,9 @@ class Blocks(BlockContext):
|
||||
else analytics.analytics_enabled()
|
||||
)
|
||||
if self.analytics_enabled:
|
||||
t = threading.Thread(target=analytics.version_check)
|
||||
t.start()
|
||||
if not wasm_utils.IS_WASM:
|
||||
t = threading.Thread(target=analytics.version_check)
|
||||
t.start()
|
||||
else:
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
|
||||
super().__init__(render=False, **kwargs)
|
||||
@ -761,7 +762,7 @@ class Blocks(BlockContext):
|
||||
self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
|
||||
self.root_urls = set()
|
||||
|
||||
if not wasm_utils.IS_WASM and self.analytics_enabled:
|
||||
if self.analytics_enabled:
|
||||
is_custom_theme = not any(
|
||||
self.theme.to_dict() == built_in_theme.to_dict()
|
||||
for built_in_theme in BUILT_IN_THEMES.values()
|
||||
@ -1866,7 +1867,7 @@ Received outputs:
|
||||
if self.local_url.startswith("https") or self.is_colab
|
||||
else "http"
|
||||
)
|
||||
if not self.is_colab:
|
||||
if not wasm_utils.IS_WASM and not self.is_colab:
|
||||
print(
|
||||
strings.en["RUNNING_LOCALLY_SEPARATED"].format(
|
||||
self.protocol, self.server_name, self.server_port
|
||||
@ -1876,13 +1877,13 @@ Received outputs:
|
||||
if self.enable_queue:
|
||||
self._queue.set_url(self.local_url)
|
||||
|
||||
# Cannot run async functions in background other than app's scope.
|
||||
# Workaround by triggering the app endpoint
|
||||
if not wasm_utils.IS_WASM:
|
||||
# Cannot run async functions in background other than app's scope.
|
||||
# Workaround by triggering the app endpoint
|
||||
requests.get(f"{self.local_url}startup-events", verify=ssl_verify)
|
||||
|
||||
if wasm_utils.IS_WASM:
|
||||
return TupleNoPrint((self.server_app, self.local_url, self.share_url))
|
||||
else:
|
||||
pass
|
||||
# TODO: Call the startup endpoint in the Wasm env too.
|
||||
|
||||
utils.launch_counter()
|
||||
self.is_sagemaker = utils.sagemaker_check()
|
||||
@ -1912,7 +1913,12 @@ Received outputs:
|
||||
|
||||
# If running in a colab or not able to access localhost,
|
||||
# a shareable link must be created.
|
||||
if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
|
||||
if (
|
||||
_frontend
|
||||
and not wasm_utils.IS_WASM
|
||||
and not networking.url_ok(self.local_url)
|
||||
and not self.share
|
||||
):
|
||||
raise ValueError(
|
||||
"When localhost is not accessible, a shareable link must be created. Please set share=True or check your proxy settings to allow access to localhost."
|
||||
)
|
||||
@ -1933,6 +1939,8 @@ Received outputs:
|
||||
if self.share:
|
||||
if self.space_id:
|
||||
raise RuntimeError("Share is not supported when you are in Spaces")
|
||||
if wasm_utils.IS_WASM:
|
||||
raise RuntimeError("Share is not supported in the Wasm environment")
|
||||
try:
|
||||
if self.share_url is None:
|
||||
self.share_url = networking.setup_tunnel(
|
||||
@ -1958,11 +1966,11 @@ Received outputs:
|
||||
)
|
||||
)
|
||||
else:
|
||||
if not (quiet):
|
||||
if not quiet and not wasm_utils.IS_WASM:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
self.share_url = None
|
||||
|
||||
if inbrowser:
|
||||
if inbrowser and not wasm_utils.IS_WASM:
|
||||
link = self.share_url if self.share and self.share_url else self.local_url
|
||||
webbrowser.open(link)
|
||||
|
||||
@ -2043,12 +2051,18 @@ Received outputs:
|
||||
utils.show_tip(self)
|
||||
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1 and not wasm_utils.IS_WASM:
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
|
||||
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
if (
|
||||
not prevent_thread_lock
|
||||
and not is_in_interactive_mode
|
||||
# In the Wasm env, we don't have to block the main thread because the server won't be shut down after the execution finishes.
|
||||
# Moreover, we MUST NOT do it because there is only one thread in the Wasm env and blocking it will stop the subsequent code from running.
|
||||
and not wasm_utils.IS_WASM
|
||||
):
|
||||
self.block_thread()
|
||||
|
||||
return TupleNoPrint((self.server_app, self.local_url, self.share_url))
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
@ -7,7 +8,7 @@ from unittest import mock as mock
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from gradio import analytics
|
||||
from gradio import analytics, wasm_utils
|
||||
from gradio.context import Context
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -33,7 +34,7 @@ class TestAnalytics:
|
||||
):
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
mock_post.side_effect = requests.ConnectionError()
|
||||
analytics._do_analytics_request("placeholder", {})
|
||||
analytics._do_normal_analytics_request("placeholder", {})
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch("requests.post")
|
||||
@ -42,6 +43,26 @@ class TestAnalytics:
|
||||
analytics.error_analytics("placeholder")
|
||||
mock_post.assert_called()
|
||||
|
||||
@mock.patch.object(wasm_utils, "IS_WASM", True)
|
||||
@mock.patch("gradio.analytics.pyodide_pyfetch")
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_analytics_successful_in_wasm_mode(
|
||||
self, pyodide_pyfetch, monkeypatch
|
||||
):
|
||||
loop = asyncio.get_event_loop()
|
||||
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
|
||||
|
||||
analytics.error_analytics("placeholder")
|
||||
|
||||
# Await all background tasks.
|
||||
# Ref: https://superfastpython.com/asyncio-wait-for-tasks/#How_to_Wait_for_All_Background_Tasks
|
||||
all_tasks = asyncio.all_tasks(loop)
|
||||
current_task = asyncio.current_task()
|
||||
all_tasks.remove(current_task)
|
||||
await asyncio.wait(all_tasks)
|
||||
|
||||
pyodide_pyfetch.assert_called()
|
||||
|
||||
|
||||
class TestIPAddress:
|
||||
@pytest.mark.flaky
|
||||
|
Loading…
Reference in New Issue
Block a user