Record username when flagging (#4135)

* record username

* fix

* changelog fix

* format

* fix hf saver

* fix deserialization

* fixes
This commit is contained in:
Abubakar Abid 2023-05-10 14:49:49 -05:00 committed by GitHub
parent 2cf13a1c69
commit ccdaac1395
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 19 deletions

View File

@ -6,6 +6,7 @@
## Bug Fixes:
- Records username when flagging by [@abidlabs](https://github.com/abidlabs) in [PR 4135](https://github.com/gradio-app/gradio/pull/4135)
- Fix website build issue by [@aliabd](https://github.com/aliabd) in [PR 4142](https://github.com/gradio-app/gradio/pull/4142)
## Documentation Changes:

View File

@ -8,6 +8,7 @@ import time
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from distutils.version import StrictVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -113,7 +114,7 @@ class SimpleCSVLogger(FlaggingCallback):
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath) as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
line_count = len(list(csv.reader(csvfile))) - 1
return line_count
@ -187,7 +188,7 @@ class CSVLogger(FlaggingCallback):
writer.writerow(utils.sanitize_list_for_csv(csv_data))
with open(log_filepath, encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
line_count = len(list(csv.reader(csvfile))) - 1
return line_count
@ -286,7 +287,12 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
except huggingface_hub.utils.EntryNotFoundError:
pass
def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
def flag(
self,
flag_data: list[Any],
flag_option: str = "",
username: str | None = None,
) -> int:
if self.separate_dirs:
# JSONL files to support dataset preview on the Hub
unique_id = str(uuid.uuid4())
@ -305,6 +311,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
path_in_repo=path_in_repo,
flag_data=flag_data,
flag_option=flag_option,
username=username or "",
)
def _flag_in_dir(
@ -314,10 +321,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
path_in_repo: str | None,
flag_data: list[Any],
flag_option: str = "",
username: str = "",
) -> int:
# Deserialize components (write images/audio to files)
features, row = self._deserialize_components(
components_dir, flag_data, flag_option
components_dir, flag_data, flag_option, username
)
# Write generic info to dataset_infos.json + upload
@ -394,18 +402,21 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
return data_file.parent.name
def _deserialize_components(
self, data_dir: Path, flag_data: list[Any], flag_option: str = ""
self,
data_dir: Path,
flag_data: list[Any],
flag_option: str = "",
username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
"""Deserialize components and return the corresponding row for the flagged sample.
Images/audio are saved to disk as individual files.
"""
# Components that can have a preview on dataset repos
# NOTE: not at root level to avoid circular imports
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
# Generate the row corresponding to the flagged sample
features = {}
features = OrderedDict()
row = []
for component, sample in zip(self.components, flag_data):
# Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
@ -415,7 +426,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
# Add deserialized object to row
features[label] = {"dtype": "string", "_type": "Value"}
row.append(Path(deserialized).name)
try:
assert Path(deserialized).exists()
row.append(Path(deserialized).name)
except (AssertionError, TypeError, ValueError):
row.append(str(deserialized))
# If component is eligible for a preview, add the URL of the file
if isinstance(component, tuple(file_preview_types)): # type: ignore
@ -436,7 +451,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
)
)
features["flag"] = {"dtype": "string", "_type": "Value"}
features["username"] = {"dtype": "string", "_type": "Value"}
row.append(flag_option)
row.append(username)
return features, row
@ -483,9 +500,11 @@ class FlagMethod:
self.__name__ = "Flag"
self.visual_feedback = visual_feedback
def __call__(self, *flag_data):
def __call__(self, request: gr.Request, *flag_data):
try:
self.flagging_callback.flag(list(flag_data), flag_option=self.value)
self.flagging_callback.flag(
list(flag_data), flag_option=self.value, username=request.username
)
except Exception as e:
print(f"Error while flagging: {e}")
if self.visual_feedback:

View File

@ -611,6 +611,7 @@ def special_args(
updated inputs, progress index, event data index.
"""
signature = inspect.signature(fn)
type_hints = utils.get_type_hints(fn)
positional_args = []
for param in signature.parameters.values():
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
@ -619,19 +620,18 @@ def special_args(
progress_index = None
event_data_index = None
for i, param in enumerate(positional_args):
type_hint = type_hints.get(param.name)
if isinstance(param.default, Progress):
progress_index = i
if inputs is not None:
inputs.insert(i, param.default)
elif param.annotation == routes.Request:
elif type_hint == routes.Request:
if inputs is not None:
inputs.insert(i, request)
elif isinstance(param.annotation, type) and issubclass(
param.annotation, EventData
):
elif type_hint and issubclass(type_hint, EventData):
event_data_index = i
if inputs is not None and event_data is not None:
inputs.insert(i, param.annotation(event_data.target, event_data._data))
inputs.insert(i, type_hint(event_data.target, event_data._data))
elif (
param.default is not param.empty and inputs is not None and len(inputs) <= i
):

View File

@ -634,10 +634,9 @@ class Interface(Blocks):
extra_output = [submit_btn, stop_btn]
cleanup = lambda: [
Button.update(visible=True),
Button.update(visible=False),
]
def cleanup():
return [Button.update(visible=True), Button.update(visible=False)]
for i, trigger in enumerate(triggers):
predict_event = trigger(
lambda: (