From a415e5abc8f98e22ec95726bb2485e2277a2d3f0 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 14 Mar 2022 09:36:51 -0500 Subject: [PATCH] sanitize flagging inputs before writing to csv Former-commit-id: f2d9f808c79abec7f2c8ff4a94b17a9b1cdd5651 --- gradio/flagging.py | 22 +++++++++++----------- gradio/utils.py | 35 ++++++++++++++++++++++++++++++++++- test/test_utils.py | 20 ++++++++++++++++++++ 3 files changed, 65 insertions(+), 12 deletions(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index f34bc3ee5b..89534973c9 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional import gradio as gr -from gradio import encryptor +from gradio import encryptor, utils class FlaggingCallback(ABC): @@ -99,7 +99,7 @@ class SimpleCSVLogger(FlaggingCallback): with open(log_filepath, "a", newline="") as csvfile: writer = csv.writer(csvfile) - writer.writerow(csv_data) + writer.writerow(utils.santize_for_csv(csv_data)) with open(log_filepath, "r") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 @@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback): content[flag_index][flag_col_index] = flag_option output = io.StringIO() writer = csv.writer(output) - writer.writerows(content) + writer.writerows(utils.santize_for_csv(content)) return output.getvalue() if interface.encrypt: @@ -200,25 +200,25 @@ class CSVLogger(FlaggingCallback): file_content = decrypted_csv.decode() if flag_index is not None: file_content = replace_flag_at_index(file_content) - output.write(file_content) + output.write(utils.santize_for_csv(file_content)) writer = csv.writer(output) if flag_index is None: if is_new: writer.writerow(headers) writer.writerow(csv_data) with open(log_fp, "wb") as csvfile: - csvfile.write( + csvfile.write(utils.santize_for_csv( encryptor.encrypt( interface.encryption_key, output.getvalue().encode() - ) + )) ) else: if flag_index is None: with open(log_fp, "a", 
# --- gradio/utils.py: helper added by this patch ---------------------------
def santize_for_csv(data: str | List[str] | List[List[str]]):
    """Escape cell contents so they are not interpreted as spreadsheet formulas.

    Spreadsheet applications treat a cell whose text starts with certain
    characters as a formula (CSV / formula injection). Any string that starts
    with one of those characters is prefixed with a single quote so it is
    rendered as plain text instead.

    Parameters:
        data: A single cell (``str``), a row (list of cells), or a list of
            rows. Lists may be empty and may be nested to any depth.
            Non-string cells (numbers, ``None``, ...) are returned unchanged;
            they are left for the CSV writer to stringify.

    Returns:
        The sanitized value with the same shape as ``data``. Lists are always
        rebuilt, so the caller's input is never mutated.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a ``list``.

    A ``UserWarning`` is emitted for every cell that had to be escaped.
    """
    # Leading characters that trigger formula evaluation. Tab and carriage
    # return are included because spreadsheet software may strip them and
    # then evaluate whatever follows.
    unsafe_prefixes = ("+", "=", "-", "@", "\t", "\r")

    def sanitize(item):
        if isinstance(item, list):
            # Recurse so rows, lists of rows, and empty lists are all handled
            # without ever indexing into the container.
            return [sanitize(element) for element in item]
        if isinstance(item, str) and item.startswith(unsafe_prefixes):
            warnings.warn("Sanitizing flagged data by escaping cell contents")
            return "'" + item
        # Safe strings and non-string cells pass through untouched.
        return item

    if not isinstance(data, (str, list)):
        raise ValueError("Unsupported data type: " + str(type(data)))
    return sanitize(data)


# --- test/test_utils.py: tests added by this patch -------------------------
class TestSanitizeForCSV(unittest.TestCase):
    """Checks that santize_for_csv escapes formula-like cells and nothing else."""

    def test_safe(self):
        # Values without a dangerous prefix must come back unchanged.
        safe_data = santize_for_csv("abc")
        self.assertEqual(safe_data, "abc")
        safe_data = santize_for_csv(["def"])
        self.assertEqual(safe_data, ["def"])
        safe_data = santize_for_csv([["abc"]])
        self.assertEqual(safe_data, [["abc"]])

    def test_unsafe(self):
        # Values with a dangerous prefix get a leading single quote.
        safe_data = santize_for_csv("=abc")
        self.assertEqual(safe_data, "'=abc")
        safe_data = santize_for_csv(["abc", "+abc"])
        self.assertEqual(safe_data, ["abc", "'+abc"])
        safe_data = santize_for_csv([["abc", "=abc"]])
        self.assertEqual(safe_data, [["abc", "'=abc"]])

    def test_empty_and_mixed(self):
        # Empty containers and non-string cells must not raise.
        self.assertEqual(santize_for_csv([]), [])
        self.assertEqual(santize_for_csv([[]]), [[]])
        self.assertEqual(santize_for_csv(["@abc", 1, None]), ["'@abc", 1, None])


if __name__ == "__main__":
    unittest.main()