mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
sanitize flagging inputs before writing to csv
Former-commit-id: f2d9f808c79abec7f2c8ff4a94b17a9b1cdd5651
This commit is contained in:
parent
81e271ca22
commit
a415e5abc8
@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import gradio as gr
|
||||
from gradio import encryptor
|
||||
from gradio import encryptor, utils
|
||||
|
||||
|
||||
class FlaggingCallback(ABC):
|
||||
@ -99,7 +99,7 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
|
||||
with open(log_filepath, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(csv_data)
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
|
||||
with open(log_filepath, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback):
|
||||
content[flag_index][flag_col_index] = flag_option
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerows(content)
|
||||
writer.writerows(utils.santize_for_csv(content))
|
||||
return output.getvalue()
|
||||
|
||||
if interface.encrypt:
|
||||
@ -200,25 +200,25 @@ class CSVLogger(FlaggingCallback):
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
output.write(file_content)
|
||||
output.write(utils.santize_for_csv(file_content))
|
||||
writer = csv.writer(output)
|
||||
if flag_index is None:
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(
|
||||
csvfile.write(utils.santize_for_csv(
|
||||
encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()
|
||||
)
|
||||
))
|
||||
)
|
||||
else:
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
writer.writerow(utils.santize_for_csv(headers))
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
else:
|
||||
with open(log_fp) as csvfile:
|
||||
file_content = csvfile.read()
|
||||
@ -226,7 +226,7 @@ class CSVLogger(FlaggingCallback):
|
||||
with open(
|
||||
log_fp, "w", newline=""
|
||||
) as csvfile: # newline parameter needed for Windows
|
||||
csvfile.write(file_content)
|
||||
csvfile.write(utils.santize_for_csv(file_content))
|
||||
with open(log_fp, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
@ -368,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"_type": "Value",
|
||||
}
|
||||
|
||||
writer.writerow(headers)
|
||||
writer.writerow(utils.santize_for_csv(headers))
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
@ -403,7 +403,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
writer.writerow(csv_data)
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
|
||||
if is_new:
|
||||
json.dump(infos, open(self.infos_file, "w"))
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import csv
|
||||
import inspect
|
||||
import json
|
||||
@ -10,7 +11,7 @@ import os
|
||||
import random
|
||||
import warnings
|
||||
from distutils.version import StrictVersion
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List
|
||||
|
||||
import aiohttp
|
||||
import analytics
|
||||
@ -286,3 +287,35 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
for v in signature.parameters.values()
|
||||
]
|
||||
|
||||
|
||||
def santize_for_csv(data: str | List[str] | List[List[str]]):
|
||||
""" Sanitizes data so that it can be safely written to a CSV file. """
|
||||
def sanitize(item):
|
||||
return "'" + item
|
||||
|
||||
unsafe_prefixes = ("+", "=", "-", "@")
|
||||
|
||||
if isinstance(data, str):
|
||||
if data.startswith(unsafe_prefixes):
|
||||
warnings.warn("Sanitizing flagged data by escaping cell contents")
|
||||
return sanitize(data)
|
||||
return data
|
||||
elif isinstance(data, list) and isinstance(data[0], str):
|
||||
sanitized_data = copy.deepcopy(data)
|
||||
for i, item in enumerate(data):
|
||||
if item.startswith(unsafe_prefixes):
|
||||
warnings.warn("Sanitizing flagged data by escaping cell contents")
|
||||
sanitized_data[i] = sanitize(item)
|
||||
return sanitized_data
|
||||
elif isinstance(data[0], list) and isinstance(data[0][0], str):
|
||||
sanitized_data = copy.deepcopy(data)
|
||||
for s, sublist in enumerate(data):
|
||||
for i, item in enumerate(sublist):
|
||||
if item.startswith(unsafe_prefixes):
|
||||
warnings.warn("Sanitizing flagged data by escaping cell contents")
|
||||
sanitized_data[s][i] = sanitize(item)
|
||||
return sanitized_data
|
||||
else:
|
||||
raise ValueError("Unsupported data type: " + str(type(data)))
|
||||
|
||||
|
@ -16,6 +16,7 @@ from gradio.utils import (
|
||||
launch_analytics,
|
||||
readme_to_html,
|
||||
version_check,
|
||||
santize_for_csv,
|
||||
)
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
@ -116,5 +117,24 @@ class TestIPAddress(unittest.TestCase):
|
||||
self.assertEqual(ip, "No internet connection")
|
||||
|
||||
|
||||
class TestSanitizeForCSV(unittest.TestCase):
|
||||
def test_safe(self):
|
||||
safe_data = santize_for_csv("abc")
|
||||
self.assertEquals(safe_data, "abc")
|
||||
safe_data = santize_for_csv(["def"])
|
||||
self.assertEquals(safe_data, ["def"])
|
||||
safe_data = santize_for_csv([["abc"]])
|
||||
self.assertEquals(safe_data, [["abc"]])
|
||||
|
||||
def test_unsafe(self):
|
||||
safe_data = santize_for_csv("=abc")
|
||||
self.assertEquals(safe_data, "'=abc")
|
||||
safe_data = santize_for_csv(["abc", "+abc"])
|
||||
self.assertEquals(safe_data, ["abc", "'+abc"])
|
||||
safe_data = santize_for_csv([["abc", "=abc"]])
|
||||
self.assertEquals(safe_data, [["abc", "'=abc"]])
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user