added feature-specific analytics

This commit is contained in:
Abubakar Abid 2020-11-24 12:46:13 -06:00
parent f8fa81efbb
commit 9221bad95f
2 changed files with 22 additions and 5 deletions

View File

@ -44,8 +44,7 @@ class Interface:
title=None, description=None, thumbnail=None,
server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
embedding="default",
flagging_dir="flagged", analytics_enabled=True):
embedding="default", flagging_dir="flagged", analytics_enabled=True):
"""
Parameters:
@ -135,7 +134,11 @@ class Interface:
'outputs': outputs,
'live': live,
'capture_session': capture_session,
'ip_address': ip_address
'ip_address': ip_address,
'interpretation': interpretation,
'embedding': embedding,
'allow_flagging': allow_flagging,
'allow_screenshot': allow_screenshot,
}
if self.capture_session:
@ -385,7 +388,7 @@ class Interface:
self.share = share
if share:
print("This share link will expire in 6 hours. If you need a "
print("This share link will expire in 24 hours. If you need a "
"permanent link, email support@gradio.app")
try:
share_url = networking.setup_tunnel(server_port)

View File

@ -30,6 +30,7 @@ TRY_NUM_PORTS = int(os.getenv(
LOCALHOST_NAME = os.getenv(
'GRADIO_SERVER_NAME', "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
GRADIO_FEATURE_ANALYTICS_URL = "https://api.gradio.app/gradio-feature-analytics/"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
@ -67,6 +68,7 @@ def get_local_ip_address():
ip_address = "No internet connection"
return ip_address
IP_ADDRESS = get_local_ip_address()
def get_first_available_port(initial, final):
"""
@ -120,6 +122,15 @@ def predict():
output = {"data": prediction, "durations": durations}
return jsonify(output)
def log_feature_analytics(feature):
if app.interface.analytics_enabled:
try:
requests.post(GRADIO_FEATURE_ANALYTICS_URL,
data={
'ip_address': IP_ADDRESS,
'feature': feature})
except requests.ConnectionError:
pass # do not push analytics if no network
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
@ -135,7 +146,7 @@ def score_similarity():
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embed(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
log_feature_analytics('score_similarity')
return jsonify({"data": scores})
@ -159,6 +170,7 @@ def view_embeddings():
sample_embedding_2d = embeddings_2d[:len(sample_embedding)]
example_embeddings_2d = embeddings_2d[len(sample_embedding):]
app.pca_model = pca_model
log_feature_analytics('view_embeddings')
return jsonify({"sample_embedding_2d": sample_embedding_2d, "example_embeddings_2d": example_embeddings_2d})
@ -194,6 +206,7 @@ def predict_examples():
@app.route("/api/flag/", methods=["POST"])
def flag():
log_feature_analytics('flag')
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
os.makedirs(flag_path,
exist_ok=True)
@ -230,6 +243,7 @@ def flag():
@app.route("/api/interpret/", methods=["POST"])
def interpret():
log_feature_analytics('interpret')
raw_input = request.json["data"]
interpretation_scores, alternative_outputs = app.interface.interpret(raw_input)
return jsonify({