Lite: Fix the analytics module to use asyncio to work in the Wasm env (#5045)

* Fix the analytics module to use asyncio to work in the Wasm env

* add changeset

* add changeset

* Add changeset

* Revert "Add changeset"

This reverts commit 052f2bd737.

* Add the `is_wasm` field to the analytics telemetry

* Restore the initialization code in Blocks.launch() in the Wasm env

* Remove a call of analytics.version_check() in Blocks.launch() in the Wasm mode

* Add a `type: ignore` directive to the line of importing the `pyodide` module

* Fix a test case

* Refactor _do_wasm_analytics_request() and add a unit test for it

* Get the IP address for analytics in the Wasm mode

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Yuichiro Tachibana (Tsuchiya) 2023-08-04 22:18:14 +09:00 committed by GitHub
parent 13e4783535
commit 3b9494f5c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 44 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Lite: Fix the analytics module to use asyncio to work in the Wasm env

View File

@ -1,10 +1,12 @@
""" Functions related to analytics and telemetry. """
from __future__ import annotations
import asyncio
import json
import os
import pkgutil
import threading
import urllib.parse
import warnings
from distutils.version import StrictVersion
from typing import Any
@ -12,9 +14,21 @@ from typing import Any
import requests
import gradio
from gradio import wasm_utils
from gradio.context import Context
from gradio.utils import GRADIO_VERSION
# For testability, we import the pyfetch function into this module scope and define a fallback coroutine object to be patched in tests.
try:
from pyodide.http import pyfetch as pyodide_pyfetch # type: ignore
except ImportError:
async def pyodide_pyfetch(*args, **kwargs):
raise NotImplementedError(
"pyodide.http.pyfetch is not available in this environment."
)
ANALYTICS_URL = "https://api.gradio.app/"
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
@ -27,6 +41,24 @@ def analytics_enabled() -> bool:
def _do_analytics_request(url: str, data: dict[str, Any]) -> None:
if wasm_utils.IS_WASM:
asyncio.ensure_future(
_do_wasm_analytics_request(
url=url,
data=data,
)
)
else:
threading.Thread(
target=_do_normal_analytics_request,
kwargs={
"url": url,
"data": data,
},
).start()
def _do_normal_analytics_request(url: str, data: dict[str, Any]) -> None:
data["ip_address"] = get_local_ip_address()
try:
requests.post(url, data=data, timeout=5)
@ -34,6 +66,25 @@ def _do_analytics_request(url: str, data: dict[str, Any]) -> None:
pass # do not push analytics if no network
async def _do_wasm_analytics_request(url: str, data: dict[str, Any]) -> None:
data["ip_address"] = await get_local_ip_address_wasm()
# We use urllib.parse.urlencode to encode the data as a form.
# Ref: https://docs.python.org/3/library/urllib.request.html#urllib-examples
body = urllib.parse.urlencode(data).encode("ascii")
headers = {
"Content-Type": "application/x-www-form-urlencoded",
}
try:
await asyncio.wait_for(
pyodide_pyfetch(url, method="POST", headers=headers, body=body),
timeout=5,
)
except asyncio.TimeoutError:
pass # do not push analytics if no network
def version_check():
try:
version_data = pkgutil.get_data(__name__, "version.txt")
@ -80,17 +131,38 @@ def get_local_ip_address() -> str:
return ip_address
async def get_local_ip_address_wasm() -> str:
"""The Wasm-compatible version of get_local_ip_address()."""
if not analytics_enabled():
return "Analytics disabled"
if Context.ip_address is None:
try:
response = await asyncio.wait_for(
pyodide_pyfetch(
# The API used by the normal version (`get_local_ip_address()`), `https://checkip.amazonaws.com/``, blocks CORS requests, so here we use a different API.
"https://api.ipify.org"
),
timeout=5,
)
response_text: str = await response.string() # type: ignore
ip_address = response_text.strip()
except (asyncio.TimeoutError, OSError):
ip_address = "No internet connection"
Context.ip_address = ip_address
else:
ip_address = Context.ip_address
return ip_address
def initiated_analytics(data: dict[str, Any]) -> None:
if not analytics_enabled():
return
threading.Thread(
target=_do_analytics_request,
kwargs={
"url": f"{ANALYTICS_URL}gradio-initiated-analytics/",
"data": data,
},
).start()
_do_analytics_request(
url=f"{ANALYTICS_URL}gradio-initiated-analytics/",
data=data,
)
def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
@ -142,30 +214,22 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
"targets": targets_telemetry,
"blocks": blocks_telemetry,
"events": [str(x["trigger"]) for x in blocks.dependencies],
"is_wasm": wasm_utils.IS_WASM,
}
data.update(additional_data)
threading.Thread(
target=_do_analytics_request,
kwargs={
"url": f"{ANALYTICS_URL}gradio-launched-telemetry/",
"data": data,
},
).start()
_do_analytics_request(url=f"{ANALYTICS_URL}gradio-launched-telemetry/", data=data)
def integration_analytics(data: dict[str, Any]) -> None:
if not analytics_enabled():
return
threading.Thread(
target=_do_analytics_request,
kwargs={
"url": f"{ANALYTICS_URL}gradio-integration-analytics/",
"data": data,
},
).start()
_do_analytics_request(
url=f"{ANALYTICS_URL}gradio-integration-analytics/",
data=data,
)
def error_analytics(message: str) -> None:
@ -179,10 +243,7 @@ def error_analytics(message: str) -> None:
data = {"error": message}
threading.Thread(
target=_do_analytics_request,
kwargs={
"url": f"{ANALYTICS_URL}gradio-error-analytics/",
"data": data,
},
).start()
_do_analytics_request(
url=f"{ANALYTICS_URL}gradio-error-analytics/",
data=data,
)

View File

@ -721,8 +721,9 @@ class Blocks(BlockContext):
else analytics.analytics_enabled()
)
if self.analytics_enabled:
t = threading.Thread(target=analytics.version_check)
t.start()
if not wasm_utils.IS_WASM:
t = threading.Thread(target=analytics.version_check)
t.start()
else:
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "True"
super().__init__(render=False, **kwargs)
@ -761,7 +762,7 @@ class Blocks(BlockContext):
self.root_path = os.environ.get("GRADIO_ROOT_PATH", "")
self.root_urls = set()
if not wasm_utils.IS_WASM and self.analytics_enabled:
if self.analytics_enabled:
is_custom_theme = not any(
self.theme.to_dict() == built_in_theme.to_dict()
for built_in_theme in BUILT_IN_THEMES.values()
@ -1866,7 +1867,7 @@ Received outputs:
if self.local_url.startswith("https") or self.is_colab
else "http"
)
if not self.is_colab:
if not wasm_utils.IS_WASM and not self.is_colab:
print(
strings.en["RUNNING_LOCALLY_SEPARATED"].format(
self.protocol, self.server_name, self.server_port
@ -1876,13 +1877,13 @@ Received outputs:
if self.enable_queue:
self._queue.set_url(self.local_url)
# Cannot run async functions in background other than app's scope.
# Workaround by triggering the app endpoint
if not wasm_utils.IS_WASM:
# Cannot run async functions in background other than app's scope.
# Workaround by triggering the app endpoint
requests.get(f"{self.local_url}startup-events", verify=ssl_verify)
if wasm_utils.IS_WASM:
return TupleNoPrint((self.server_app, self.local_url, self.share_url))
else:
pass
# TODO: Call the startup endpoint in the Wasm env too.
utils.launch_counter()
self.is_sagemaker = utils.sagemaker_check()
@ -1912,7 +1913,12 @@ Received outputs:
# If running in a colab or not able to access localhost,
# a shareable link must be created.
if _frontend and (not networking.url_ok(self.local_url)) and (not self.share):
if (
_frontend
and not wasm_utils.IS_WASM
and not networking.url_ok(self.local_url)
and not self.share
):
raise ValueError(
"When localhost is not accessible, a shareable link must be created. Please set share=True or check your proxy settings to allow access to localhost."
)
@ -1933,6 +1939,8 @@ Received outputs:
if self.share:
if self.space_id:
raise RuntimeError("Share is not supported when you are in Spaces")
if wasm_utils.IS_WASM:
raise RuntimeError("Share is not supported in the Wasm environment")
try:
if self.share_url is None:
self.share_url = networking.setup_tunnel(
@ -1958,11 +1966,11 @@ Received outputs:
)
)
else:
if not (quiet):
if not quiet and not wasm_utils.IS_WASM:
print(strings.en["PUBLIC_SHARE_TRUE"])
self.share_url = None
if inbrowser:
if inbrowser and not wasm_utils.IS_WASM:
link = self.share_url if self.share and self.share_url else self.local_url
webbrowser.open(link)
@ -2043,12 +2051,18 @@ Received outputs:
utils.show_tip(self)
# Block main thread if debug==True
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1 and not wasm_utils.IS_WASM:
self.block_thread()
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
if not prevent_thread_lock and not is_in_interactive_mode:
if (
not prevent_thread_lock
and not is_in_interactive_mode
# In the Wasm env, we don't have to block the main thread because the server won't be shut down after the execution finishes.
# Moreover, we MUST NOT do it because there is only one thread in the Wasm env and blocking it will stop the subsequent code from running.
and not wasm_utils.IS_WASM
):
self.block_thread()
return TupleNoPrint((self.server_app, self.local_url, self.share_url))

View File

@ -1,3 +1,4 @@
import asyncio
import ipaddress
import json
import os
@ -7,7 +8,7 @@ from unittest import mock as mock
import pytest
import requests
from gradio import analytics
from gradio import analytics, wasm_utils
from gradio.context import Context
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -33,7 +34,7 @@ class TestAnalytics:
):
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
mock_post.side_effect = requests.ConnectionError()
analytics._do_analytics_request("placeholder", {})
analytics._do_normal_analytics_request("placeholder", {})
mock_post.assert_called()
@mock.patch("requests.post")
@ -42,6 +43,26 @@ class TestAnalytics:
analytics.error_analytics("placeholder")
mock_post.assert_called()
@mock.patch.object(wasm_utils, "IS_WASM", True)
@mock.patch("gradio.analytics.pyodide_pyfetch")
@pytest.mark.asyncio
async def test_error_analytics_successful_in_wasm_mode(
self, pyodide_pyfetch, monkeypatch
):
loop = asyncio.get_event_loop()
monkeypatch.setenv("GRADIO_ANALYTICS_ENABLED", "True")
analytics.error_analytics("placeholder")
# Await all background tasks.
# Ref: https://superfastpython.com/asyncio-wait-for-tasks/#How_to_Wait_for_All_Background_Tasks
all_tasks = asyncio.all_tasks(loop)
current_task = asyncio.current_task()
all_tasks.remove(current_task)
await asyncio.wait(all_tasks)
pyodide_pyfetch.assert_called()
class TestIPAddress:
@pytest.mark.flaky