Fixes cached examples (#1850)

* examples fix

* fix bug

* formatting
This commit is contained in:
Abubakar Abid 2022-07-21 12:57:58 -07:00 committed by GitHub
parent 817819d7a1
commit 6b1de93ab7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 13 deletions

View File

@ -8,6 +8,7 @@ import os
import shutil
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from gradio import utils
from gradio.components import Dataset
from gradio.documentation import document, set_documentation_group
from gradio.flagging import CSVLogger
@ -153,13 +154,13 @@ class Examples:
self.cache_interface_examples()
def load_example(example_id):
processed_example = self.processed_examples[example_id]
if cache_examples:
processed_example += self.load_from_cache(example_id)
if len(processed_example) == 1:
return processed_example[0]
processed_example = self.processed_examples[
example_id
] + self.load_from_cache(example_id)
else:
return processed_example
processed_example = self.processed_examples[example_id]
return utils.resolve_singleton(processed_example)
dataset.click(
load_example,

View File

@ -163,7 +163,7 @@ class CSVLogger(FlaggingCallback):
if self.encryption_key:
output = io.StringIO()
if not is_new:
with open(log_filepath, "rb") as csvfile:
with open(log_filepath, "rb", encoding="utf-8") as csvfile:
encrypted_csv = csvfile.read()
decrypted_csv = encryptor.decrypt(
self.encryption_key, encrypted_csv
@ -177,13 +177,13 @@ class CSVLogger(FlaggingCallback):
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_filepath, "wb") as csvfile:
with open(log_filepath, "wb", encoding="utf-8") as csvfile:
csvfile.write(
encryptor.encrypt(self.encryption_key, output.getvalue().encode())
)
else:
if flag_index is None:
with open(log_filepath, "a", newline="") as csvfile:
with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(
csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
)
@ -191,14 +191,14 @@ class CSVLogger(FlaggingCallback):
writer.writerow(headers)
writer.writerow(csv_data)
else:
with open(log_filepath) as csvfile:
with open(log_filepath, encoding="utf-8") as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
with open(
log_filepath, "w", newline=""
log_filepath, "w", newline="", encoding="utf-8"
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
with open(log_filepath, "r") as csvfile:
with open(log_filepath, "r", encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -282,7 +282,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
is_new = not os.path.exists(self.log_file)
infos = {"flagged": {"features": {}}}
with open(self.log_file, "a", newline="") as csvfile:
with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
# File previews for certain input and output types
@ -338,7 +338,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
if is_new:
json.dump(infos, open(self.infos_file, "w"))
with open(self.log_file, "r") as csvfile:
with open(self.log_file, "r", encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))