mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-23 11:39:17 +08:00
rewrote DefaultFlaggingHandler not to need app
This commit is contained in:
parent
00eccfd371
commit
27a4b4e637
@ -1,4 +1,4 @@
|
||||
Metadata-Version: 1.0
|
||||
Metadata-Version: 2.1
|
||||
Name: gradio
|
||||
Version: 2.4.0
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
@ -6,6 +6,9 @@ Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid
|
||||
Author-email: a12d@stanford.edu
|
||||
License: Apache License 2.0
|
||||
Description: UNKNOWN
|
||||
Keywords: machine learning,visualization,reproducibility
|
||||
Platform: UNKNOWN
|
||||
License-File: LICENSE
|
||||
|
||||
UNKNOWN
|
||||
|
||||
|
@ -1,15 +1,15 @@
|
||||
Flask-Cors>=3.0.8
|
||||
Flask-Login
|
||||
Flask>=1.1.1
|
||||
analytics-python
|
||||
ffmpy
|
||||
flask-cachebuster
|
||||
markdown2
|
||||
matplotlib
|
||||
numpy
|
||||
pandas
|
||||
paramiko
|
||||
pillow
|
||||
pycryptodome
|
||||
pydub
|
||||
matplotlib
|
||||
pandas
|
||||
pillow
|
||||
ffmpy
|
||||
markdown2
|
||||
pycryptodome
|
||||
requests
|
||||
paramiko
|
||||
analytics-python
|
||||
Flask>=1.1.1
|
||||
Flask-Cors>=3.0.8
|
||||
flask-cachebuster
|
||||
Flask-Login
|
||||
|
@ -10,12 +10,12 @@ class FlaggingHandler():
|
||||
"""
|
||||
A class for defining the methods that any FlaggingHandler should have.
|
||||
"""
|
||||
def __init__(self, app, **kwargs):
|
||||
def __init__(self, interface, **kwargs):
|
||||
"""
|
||||
Parameters:
|
||||
app: Flask app running the interface (in gradio.networking)
|
||||
interface: The interface object that the FlaggingHandler is being used with.
|
||||
"""
|
||||
self.app = app
|
||||
self.interface = interface
|
||||
self.kwargs = kwargs
|
||||
|
||||
def flag(self, input_data, output_data, flag_option=None, flag_index=None, username=None, path=None):
|
||||
@ -34,20 +34,18 @@ class FlaggingHandler():
|
||||
|
||||
class DefaultFlaggingHandler(FlaggingHandler):
|
||||
def flag(self, input_data, output_data, flag_option=None, flag_index=None, username=None, flag_path=None):
|
||||
if flag_path is None:
|
||||
flag_path = os.path.join(self.app.cwd, self.app.interface.flagging_dir)
|
||||
log_fp = "{}/log.csv".format(flag_path)
|
||||
encryption_key = self.app.interface.encryption_key if self.app.interface.encrypt else None
|
||||
encryption_key = self.interface.encryption_key if self.interface.encrypt else None
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
for i, interface in enumerate(self.app.interface.input_components):
|
||||
for i, interface in enumerate(self.interface.input_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, self.app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
|
||||
for i, interface in enumerate(self.app.interface.output_components):
|
||||
flag_path, self.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
|
||||
for i, interface in enumerate(self.interface.output_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, self.app.interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
|
||||
flag_path, self.interface.config["output_components"][i]["label"], output_data[i], encryption_key) if
|
||||
output_data[i] is not None else "")
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
@ -56,10 +54,10 @@ class DefaultFlaggingHandler(FlaggingHandler):
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
if is_new:
|
||||
headers = [interface["label"]
|
||||
for interface in self.app.interface.config["input_components"]]
|
||||
for interface in self.interface.config["input_components"]]
|
||||
headers += [interface["label"]
|
||||
for interface in self.app.interface.config["output_components"]]
|
||||
if self.app.interface.flagging_options is not None:
|
||||
for interface in self.interface.config["output_components"]]
|
||||
if self.interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
if username is not None:
|
||||
headers.append("username")
|
||||
@ -76,13 +74,13 @@ class DefaultFlaggingHandler(FlaggingHandler):
|
||||
writer.writerows(content)
|
||||
return output.getvalue()
|
||||
|
||||
if self.app.interface.encrypt:
|
||||
if self.interface.encrypt:
|
||||
output = io.StringIO()
|
||||
if not is_new:
|
||||
with open(log_fp, "rb") as csvfile:
|
||||
encrypted_csv = csvfile.read()
|
||||
decrypted_csv = encryptor.decrypt(
|
||||
self.app.interface.encryption_key, encrypted_csv)
|
||||
self.interface.encryption_key, encrypted_csv)
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
@ -94,7 +92,7 @@ class DefaultFlaggingHandler(FlaggingHandler):
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(encryptor.encrypt(
|
||||
self.app.interface.encryption_key, output.getvalue().encode()))
|
||||
self.interface.encryption_key, output.getvalue().encode()))
|
||||
else:
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a", newline="") as csvfile:
|
||||
|
@ -170,6 +170,9 @@ class Interface:
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
|
||||
if self.allow_flagging:
|
||||
# TODO(abidlabs): If inside a HuggingSpace, then instantiate a HuggingFaceFlaggingHandler() instead
|
||||
self.flagging_handler = gradio.flagging.DefaultFlaggingHandler(self)
|
||||
self.flagging_options = flagging_options
|
||||
self.flagging_dir = flagging_dir
|
||||
self.encrypt = encrypt
|
||||
|
@ -196,10 +196,10 @@ def predict():
|
||||
output = {"data": prediction, "durations": durations, "avg_durations": avg_durations}
|
||||
if app.interface.allow_flagging == "auto":
|
||||
try:
|
||||
handler = flagging.DefaultFlaggingHandler(app)
|
||||
flag_index = handler.flag(raw_input, prediction,
|
||||
flag_index = app.interface.flagging_handler.flag(raw_input, prediction,
|
||||
flag_option=(None if app.interface.flagging_options is None else ""),
|
||||
username=current_user.id if current_user.is_authenticated else None)
|
||||
username=current_user.id if current_user.is_authenticated else None,
|
||||
flag_path=os.path.join(app.cwd, app.interface.flagging_dir))
|
||||
output["flag_index"] = flag_index
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
21
test/test_flagging.py
Normal file
21
test/test_flagging.py
Normal file
@ -0,0 +1,21 @@
|
||||
import gradio as gr
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
def test_num_rows_written(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
io.launch(prevent_thread_lock=True)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
row_count = io.flagging_handler.flag(["test"], ["test"], flag_path=tmpdirname)
|
||||
self.assertEquals(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_handler.flag("test", "test", flag_path=tmpdirname)
|
||||
self.assertEquals(row_count, 2) # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -5,7 +5,6 @@ import unittest.mock as mock
|
||||
import ipaddress
|
||||
import requests
|
||||
import warnings
|
||||
import tempfile
|
||||
from unittest.mock import ANY
|
||||
import urllib.request
|
||||
|
||||
@ -96,6 +95,14 @@ class TestFlaskRoutes(unittest.TestCase):
|
||||
response = self.client.post('/api/queue/status/', json={"hash": "test"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_flagging_analytics(self):
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
with mock.patch('gradio.networking.flag_data') as mock_flag:
|
||||
response = self.client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
|
||||
mock_post.assert_called_once()
|
||||
mock_flag.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
gr.reset_all()
|
||||
@ -144,28 +151,6 @@ class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
mock_post.assert_not_called()
|
||||
io.close()
|
||||
|
||||
class TestFlagging(unittest.TestCase):
|
||||
def test_num_rows_written(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
io.launch(prevent_thread_lock=True)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
row_count = networking.flag_data(["test"], ["test"], flag_path=tmpdirname)
|
||||
self.assertEquals(row_count, 1) # 2 rows written including header
|
||||
row_count = networking.flag_data("test", "test", flag_path=tmpdirname)
|
||||
self.assertEquals(row_count, 2) # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
def test_flagging_analytics(self):
|
||||
io = gr.Interface(lambda x: x, "text", "text")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = app.test_client()
|
||||
with mock.patch('requests.post') as mock_post:
|
||||
with mock.patch('gradio.networking.flag_data') as mock_flag:
|
||||
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
|
||||
mock_post.assert_called_once()
|
||||
mock_flag.assert_called_once()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
io.close()
|
||||
|
||||
class TestInterpretation(unittest.TestCase):
|
||||
def test_interpretation(self):
|
||||
|
Loading…
Reference in New Issue
Block a user