mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-07 11:46:51 +08:00
Record username when flagging (#4135)
* record username * fix * changelog fix * format * fix hf saver * fix deserialization * fixes
This commit is contained in:
parent
2cf13a1c69
commit
ccdaac1395
@ -6,6 +6,7 @@
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
- Records username when flagging by [@abidlabs](https://github.com/abidlabs) in [PR 4135](https://github.com/gradio-app/gradio/pull/4135)
|
||||
- Fix website build issue by [@aliabd](https://github.com/aliabd) in [PR 4142](https://github.com/gradio-app/gradio/pull/4142)
|
||||
|
||||
## Documentation Changes:
|
||||
|
@ -8,6 +8,7 @@ import time
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from distutils.version import StrictVersion
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -113,7 +114,7 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
writer.writerow(utils.sanitize_list_for_csv(csv_data))
|
||||
|
||||
with open(log_filepath) as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
line_count = len(list(csv.reader(csvfile))) - 1
|
||||
return line_count
|
||||
|
||||
|
||||
@ -187,7 +188,7 @@ class CSVLogger(FlaggingCallback):
|
||||
writer.writerow(utils.sanitize_list_for_csv(csv_data))
|
||||
|
||||
with open(log_filepath, encoding="utf-8") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
line_count = len(list(csv.reader(csvfile))) - 1
|
||||
return line_count
|
||||
|
||||
|
||||
@ -286,7 +287,12 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
pass
|
||||
|
||||
def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
|
||||
def flag(
|
||||
self,
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str | None = None,
|
||||
) -> int:
|
||||
if self.separate_dirs:
|
||||
# JSONL files to support dataset preview on the Hub
|
||||
unique_id = str(uuid.uuid4())
|
||||
@ -305,6 +311,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
path_in_repo=path_in_repo,
|
||||
flag_data=flag_data,
|
||||
flag_option=flag_option,
|
||||
username=username or "",
|
||||
)
|
||||
|
||||
def _flag_in_dir(
|
||||
@ -314,10 +321,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
path_in_repo: str | None,
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str = "",
|
||||
) -> int:
|
||||
# Deserialize components (write images/audio to files)
|
||||
features, row = self._deserialize_components(
|
||||
components_dir, flag_data, flag_option
|
||||
components_dir, flag_data, flag_option, username
|
||||
)
|
||||
|
||||
# Write generic info to dataset_infos.json + upload
|
||||
@ -394,18 +402,21 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
return data_file.parent.name
|
||||
|
||||
def _deserialize_components(
|
||||
self, data_dir: Path, flag_data: list[Any], flag_option: str = ""
|
||||
self,
|
||||
data_dir: Path,
|
||||
flag_data: list[Any],
|
||||
flag_option: str = "",
|
||||
username: str = "",
|
||||
) -> tuple[dict[Any, Any], list[Any]]:
|
||||
"""Deserialize components and return the corresponding row for the flagged sample.
|
||||
|
||||
Images/audio are saved to disk as individual files.
|
||||
"""
|
||||
# Components that can have a preview on dataset repos
|
||||
# NOTE: not at root level to avoid circular imports
|
||||
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
features = {}
|
||||
features = OrderedDict()
|
||||
row = []
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
# Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
|
||||
@ -415,7 +426,11 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
|
||||
# Add deserialized object to row
|
||||
features[label] = {"dtype": "string", "_type": "Value"}
|
||||
row.append(Path(deserialized).name)
|
||||
try:
|
||||
assert Path(deserialized).exists()
|
||||
row.append(Path(deserialized).name)
|
||||
except (AssertionError, TypeError, ValueError):
|
||||
row.append(str(deserialized))
|
||||
|
||||
# If component is eligible for a preview, add the URL of the file
|
||||
if isinstance(component, tuple(file_preview_types)): # type: ignore
|
||||
@ -436,7 +451,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
)
|
||||
)
|
||||
features["flag"] = {"dtype": "string", "_type": "Value"}
|
||||
features["username"] = {"dtype": "string", "_type": "Value"}
|
||||
row.append(flag_option)
|
||||
row.append(username)
|
||||
return features, row
|
||||
|
||||
|
||||
@ -483,9 +500,11 @@ class FlagMethod:
|
||||
self.__name__ = "Flag"
|
||||
self.visual_feedback = visual_feedback
|
||||
|
||||
def __call__(self, *flag_data):
|
||||
def __call__(self, request: gr.Request, *flag_data):
|
||||
try:
|
||||
self.flagging_callback.flag(list(flag_data), flag_option=self.value)
|
||||
self.flagging_callback.flag(
|
||||
list(flag_data), flag_option=self.value, username=request.username
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error while flagging: {e}")
|
||||
if self.visual_feedback:
|
||||
|
@ -611,6 +611,7 @@ def special_args(
|
||||
updated inputs, progress index, event data index.
|
||||
"""
|
||||
signature = inspect.signature(fn)
|
||||
type_hints = utils.get_type_hints(fn)
|
||||
positional_args = []
|
||||
for param in signature.parameters.values():
|
||||
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
|
||||
@ -619,19 +620,18 @@ def special_args(
|
||||
progress_index = None
|
||||
event_data_index = None
|
||||
for i, param in enumerate(positional_args):
|
||||
type_hint = type_hints.get(param.name)
|
||||
if isinstance(param.default, Progress):
|
||||
progress_index = i
|
||||
if inputs is not None:
|
||||
inputs.insert(i, param.default)
|
||||
elif param.annotation == routes.Request:
|
||||
elif type_hint == routes.Request:
|
||||
if inputs is not None:
|
||||
inputs.insert(i, request)
|
||||
elif isinstance(param.annotation, type) and issubclass(
|
||||
param.annotation, EventData
|
||||
):
|
||||
elif type_hint and issubclass(type_hint, EventData):
|
||||
event_data_index = i
|
||||
if inputs is not None and event_data is not None:
|
||||
inputs.insert(i, param.annotation(event_data.target, event_data._data))
|
||||
inputs.insert(i, type_hint(event_data.target, event_data._data))
|
||||
elif (
|
||||
param.default is not param.empty and inputs is not None and len(inputs) <= i
|
||||
):
|
||||
|
@ -634,10 +634,9 @@ class Interface(Blocks):
|
||||
|
||||
extra_output = [submit_btn, stop_btn]
|
||||
|
||||
cleanup = lambda: [
|
||||
Button.update(visible=True),
|
||||
Button.update(visible=False),
|
||||
]
|
||||
def cleanup():
|
||||
return [Button.update(visible=True), Button.update(visible=False)]
|
||||
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_event = trigger(
|
||||
lambda: (
|
||||
|
Loading…
Reference in New Issue
Block a user