starting async

This commit is contained in:
Abubakar Abid 2022-01-06 18:43:39 -05:00
parent 560b6d3b64
commit 9cb642abbb
3 changed files with 16 additions and 11 deletions

View File

@ -201,7 +201,7 @@ async def flag(
username: str = Depends(get_current_user)
):
if app.interface.analytics_enabled:
utils.log_feature_analytics(app.interface.ip_address, 'flag')
await utils.log_feature_analytics(app.interface.ip_address, 'flag')
body = await request.json()
data = body['data']
app.interface.flagging_callback.flag(
@ -214,7 +214,7 @@ async def flag(
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
if app.interface.analytics_enabled:
utils.log_feature_analytics(app.interface.ip_address, 'interpret')
await utils.log_feature_analytics(app.interface.ip_address, 'interpret')
body = await request.json()
raw_input = body["data"]
interpretation_scores, alternative_outputs = app.interface.interpret(

View File

@ -1,6 +1,7 @@
""" Handy utility functions."""
from __future__ import annotations
import aiohttp
import analytics
import csv
from distutils.version import StrictVersion
@ -11,7 +12,7 @@ import os
import pkg_resources
import random
import requests
from typing import Callable, Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Callable, Any, Dict, TYPE_CHECKING
import warnings
import gradio
@ -43,7 +44,7 @@ def version_check():
except KeyError:
warnings.warn("package URL does not contain version info.")
except:
warnings.warn("unable to connect with package URL to collect version info.")
pass
def get_local_ip_address() -> str:
@ -92,13 +93,16 @@ def error_analytics(type: RuntimeError | NameError) -> None:
pass # do not push analytics if no network
def log_feature_analytics(ip_address: str, feature: str) -> None:
async def log_feature_analytics(ip_address: str, feature: str) -> None:
data={'ip_address': ip_address, 'feature': feature}
try:
requests.post(analytics_url + 'gradio-feature-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
async with aiohttp.ClientSession() as session:
try:
async with session.post(
analytics_url + 'gradio-feature-analytics/',
data=data) as resp:
await resp
except (aiohttp.ClientError):
pass # do not push analytics if no network
def colab_check() -> bool:
@ -113,7 +117,7 @@ def colab_check() -> bool:
if "google.colab" in str(from_ipynb):
is_colab = True
except (ImportError, NameError):
error_analytics("NameError")
pass
return is_colab

View File

@ -16,6 +16,7 @@ setup(
keywords=['machine learning', 'visualization', 'reproducibility'],
install_requires=[
'analytics-python',
'aiohttp',
'fastapi',
'ffmpy',
'markdown2',