sanitize flagging inputs before writing to csv

Former-commit-id: f2d9f808c79abec7f2c8ff4a94b17a9b1cdd5651
This commit is contained in:
Abubakar Abid 2022-03-14 09:36:51 -05:00
parent 81e271ca22
commit a415e5abc8
3 changed files with 65 additions and 12 deletions

View File

@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
from typing import Any, List, Optional
import gradio as gr
from gradio import encryptor
from gradio import encryptor, utils
class FlaggingCallback(ABC):
@ -99,7 +99,7 @@ class SimpleCSVLogger(FlaggingCallback):
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(csv_data))
with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback):
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(content)
writer.writerows(utils.santize_for_csv(content))
return output.getvalue()
if interface.encrypt:
@ -200,25 +200,25 @@ class CSVLogger(FlaggingCallback):
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(file_content)
output.write(utils.santize_for_csv(file_content))
writer = csv.writer(output)
if flag_index is None:
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(
csvfile.write(utils.santize_for_csv(
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
)
))
)
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(headers))
writer.writerow(utils.santize_for_csv(csv_data))
else:
with open(log_fp) as csvfile:
file_content = csvfile.read()
@ -226,7 +226,7 @@ class CSVLogger(FlaggingCallback):
with open(
log_fp, "w", newline=""
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
csvfile.write(utils.santize_for_csv(file_content))
with open(log_fp, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -368,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"_type": "Value",
}
writer.writerow(headers)
writer.writerow(utils.santize_for_csv(headers))
# Generate the row corresponding to the flagged sample
csv_data = []
@ -403,7 +403,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
if flag_option is not None:
csv_data.append(flag_option)
writer.writerow(csv_data)
writer.writerow(utils.santize_for_csv(csv_data))
if is_new:
json.dump(infos, open(self.infos_file, "w"))

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import copy
import csv
import inspect
import json
@ -10,7 +11,7 @@ import os
import random
import warnings
from distutils.version import StrictVersion
from typing import TYPE_CHECKING, Any, Callable, Dict
from typing import TYPE_CHECKING, Any, Callable, Dict, List
import aiohttp
import analytics
@ -286,3 +287,35 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values()
]
def santize_for_csv(data: str | List[str] | List[List[str]]):
""" Sanitizes data so that it can be safely written to a CSV file. """
def sanitize(item):
return "'" + item
unsafe_prefixes = ("+", "=", "-", "@")
if isinstance(data, str):
if data.startswith(unsafe_prefixes):
warnings.warn("Sanitizing flagged data by escaping cell contents")
return sanitize(data)
return data
elif isinstance(data, list) and isinstance(data[0], str):
sanitized_data = copy.deepcopy(data)
for i, item in enumerate(data):
if item.startswith(unsafe_prefixes):
warnings.warn("Sanitizing flagged data by escaping cell contents")
sanitized_data[i] = sanitize(item)
return sanitized_data
elif isinstance(data[0], list) and isinstance(data[0][0], str):
sanitized_data = copy.deepcopy(data)
for s, sublist in enumerate(data):
for i, item in enumerate(sublist):
if item.startswith(unsafe_prefixes):
warnings.warn("Sanitizing flagged data by escaping cell contents")
sanitized_data[s][i] = sanitize(item)
return sanitized_data
else:
raise ValueError("Unsupported data type: " + str(type(data)))

View File

@ -16,6 +16,7 @@ from gradio.utils import (
launch_analytics,
readme_to_html,
version_check,
santize_for_csv,
)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -116,5 +117,24 @@ class TestIPAddress(unittest.TestCase):
self.assertEqual(ip, "No internet connection")
class TestSanitizeForCSV(unittest.TestCase):
    """Checks santize_for_csv on single cells, rows and tables."""

    def test_safe(self):
        """Data without an unsafe leading character is returned unchanged."""
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        safe_data = santize_for_csv("abc")
        self.assertEqual(safe_data, "abc")
        safe_data = santize_for_csv(["def"])
        self.assertEqual(safe_data, ["def"])
        safe_data = santize_for_csv([["abc"]])
        self.assertEqual(safe_data, [["abc"]])

    def test_unsafe(self):
        """Cells starting with an unsafe character get a leading quote."""
        safe_data = santize_for_csv("=abc")
        self.assertEqual(safe_data, "'=abc")
        safe_data = santize_for_csv(["abc", "+abc"])
        self.assertEqual(safe_data, ["abc", "'+abc"])
        safe_data = santize_for_csv([["abc", "=abc"]])
        self.assertEqual(safe_data, [["abc", "'=abc"]])
if __name__ == "__main__":
unittest.main()