rewrote DefaultFlaggingHandler not to need app

This commit is contained in:
Abubakar Abid 2021-11-01 11:40:51 -05:00
parent 00eccfd371
commit 27a4b4e637
7 changed files with 66 additions and 56 deletions

View File

@ -1,4 +1,4 @@
Metadata-Version: 1.0
Metadata-Version: 2.1
Name: gradio
Version: 2.4.0
Summary: Python library for easily interacting with trained machine learning models
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid
Author-email: a12d@stanford.edu
License: Apache License 2.0
Description: UNKNOWN
Keywords: machine learning,visualization,reproducibility
Platform: UNKNOWN
License-File: LICENSE
UNKNOWN

View File

@ -1,15 +1,15 @@
Flask-Cors>=3.0.8
Flask-Login
Flask>=1.1.1
analytics-python
ffmpy
flask-cachebuster
markdown2
matplotlib
numpy
pandas
paramiko
pillow
pycryptodome
pydub
matplotlib
pandas
pillow
ffmpy
markdown2
pycryptodome
requests
paramiko
analytics-python
Flask>=1.1.1
Flask-Cors>=3.0.8
flask-cachebuster
Flask-Login

View File

@ -10,12 +10,12 @@ class FlaggingHandler():
"""
A class for defining the methods that any FlaggingHandler should have.
"""
def __init__(self, app, **kwargs):
def __init__(self, interface, **kwargs):
"""
Parameters:
app: Flask app running the interface (in gradio.networking)
interface: The interface object that the FlaggingHandler is being used with.
"""
self.app = app
self.interface = interface
self.kwargs = kwargs
def flag(self, input_data, output_data, flag_option=None, flag_index=None, username=None, path=None):
@ -34,20 +34,18 @@ class FlaggingHandler():
class DefaultFlaggingHandler(FlaggingHandler):
def flag(self, input_data, output_data, flag_option=None, flag_index=None, username=None, flag_path=None):
if flag_path is None:
flag_path = os.path.join(self.app.cwd, self.app.interface.flagging_dir)
log_fp = "{}/log.csv".format(flag_path)
encryption_key = self.app.interface.encryption_key if self.app.interface.encrypt else None
encryption_key = self.interface.encryption_key if self.interface.encrypt else None
is_new = not os.path.exists(log_fp)
if flag_index is None:
csv_data = []
for i, interface in enumerate(self.app.interface.input_components):
for i, interface in enumerate(self.interface.input_components):
csv_data.append(interface.save_flagged(
flag_path, self.app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
for i, interface in enumerate(self.app.interface.output_components):
flag_path, self.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
for i, interface in enumerate(self.interface.output_components):
csv_data.append(interface.save_flagged(
flag_path, self.app.interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
flag_path, self.interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
output_data[i] is not None else "")
if flag_option is not None:
csv_data.append(flag_option)
@ -56,10 +54,10 @@ class DefaultFlaggingHandler(FlaggingHandler):
csv_data.append(str(datetime.datetime.now()))
if is_new:
headers = [interface["label"]
for interface in self.app.interface.config["input_components"]]
for interface in self.interface.config["input_components"]]
headers += [interface["label"]
for interface in self.app.interface.config["output_components"]]
if self.app.interface.flagging_options is not None:
for interface in self.interface.config["output_components"]]
if self.interface.flagging_options is not None:
headers.append("flag")
if username is not None:
headers.append("username")
@ -76,13 +74,13 @@ class DefaultFlaggingHandler(FlaggingHandler):
writer.writerows(content)
return output.getvalue()
if self.app.interface.encrypt:
if self.interface.encrypt:
output = io.StringIO()
if not is_new:
with open(log_fp, "rb") as csvfile:
encrypted_csv = csvfile.read()
decrypted_csv = encryptor.decrypt(
self.app.interface.encryption_key, encrypted_csv)
self.interface.encryption_key, encrypted_csv)
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
@ -94,7 +92,7 @@ class DefaultFlaggingHandler(FlaggingHandler):
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(encryptor.encrypt(
self.app.interface.encryption_key, output.getvalue().encode()))
self.interface.encryption_key, output.getvalue().encode()))
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:

View File

@ -170,6 +170,9 @@ class Interface:
self.simple_server = None
self.allow_screenshot = allow_screenshot
self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
if self.allow_flagging:
# TODO(abidlabs): If inside a HuggingSpace, then instantiate a HuggingFaceFlaggingHandler() instead
self.flagging_handler = gradio.flagging.DefaultFlaggingHandler(self)
self.flagging_options = flagging_options
self.flagging_dir = flagging_dir
self.encrypt = encrypt

View File

@ -196,10 +196,10 @@ def predict():
output = {"data": prediction, "durations": durations, "avg_durations": avg_durations}
if app.interface.allow_flagging == "auto":
try:
handler = flagging.DefaultFlaggingHandler(app)
flag_index = handler.flag(raw_input, prediction,
flag_index = app.interface.flagging_handler.flag(raw_input, prediction,
flag_option=(None if app.interface.flagging_options is None else ""),
username=current_user.id if current_user.is_authenticated else None)
username=current_user.id if current_user.is_authenticated else None,
flag_path=os.path.join(app.cwd, app.interface.flagging_dir))
output["flag_index"] = flag_index
except Exception as e:
print(str(e))

21
test/test_flagging.py Normal file
View File

@ -0,0 +1,21 @@
import gradio as gr
import tempfile
import unittest
import unittest.mock as mock
class TestFlagging(unittest.TestCase):
def test_num_rows_written(self):
io = gr.Interface(lambda x: x, "text", "text")
io.launch(prevent_thread_lock=True)
with tempfile.TemporaryDirectory() as tmpdirname:
row_count = io.flagging_handler.flag(["test"], ["test"], flag_path=tmpdirname)
self.assertEquals(row_count, 1) # 2 rows written including header
row_count = io.flagging_handler.flag("test", "test", flag_path=tmpdirname)
self.assertEquals(row_count, 2) # 3 rows written including header
io.close()
if __name__ == '__main__':
unittest.main()

View File

@ -5,7 +5,6 @@ import unittest.mock as mock
import ipaddress
import requests
import warnings
import tempfile
from unittest.mock import ANY
import urllib.request
@ -96,6 +95,14 @@ class TestFlaskRoutes(unittest.TestCase):
response = self.client.post('/api/queue/status/', json={"hash": "test"})
self.assertEqual(response.status_code, 200)
def test_flagging_analytics(self):
with mock.patch('requests.post') as mock_post:
with mock.patch('gradio.networking.flag_data') as mock_flag:
response = self.client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
mock_post.assert_called_once()
mock_flag.assert_called_once()
self.assertEqual(response.status_code, 200)
def tearDown(self) -> None:
self.io.close()
gr.reset_all()
@ -144,28 +151,6 @@ class TestInterfaceCustomParameters(unittest.TestCase):
mock_post.assert_not_called()
io.close()
class TestFlagging(unittest.TestCase):
def test_num_rows_written(self):
io = gr.Interface(lambda x: x, "text", "text")
io.launch(prevent_thread_lock=True)
with tempfile.TemporaryDirectory() as tmpdirname:
row_count = networking.flag_data(["test"], ["test"], flag_path=tmpdirname)
self.assertEquals(row_count, 1) # 2 rows written including header
row_count = networking.flag_data("test", "test", flag_path=tmpdirname)
self.assertEquals(row_count, 2) # 3 rows written including header
io.close()
def test_flagging_analytics(self):
io = gr.Interface(lambda x: x, "text", "text")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
with mock.patch('requests.post') as mock_post:
with mock.patch('gradio.networking.flag_data') as mock_flag:
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
mock_post.assert_called_once()
mock_flag.assert_called_once()
self.assertEqual(response.status_code, 200)
io.close()
class TestInterpretation(unittest.TestCase):
def test_interpretation(self):