From abb85c0f9b17be2443b75a6fbf4746a906f0d59e Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Wed, 16 Mar 2022 19:26:39 -0500 Subject: [PATCH 1/8] changes --- gradio/flagging.py | 26 ++++---- gradio/interface.py | 21 ++++--- gradio/utils.py | 4 +- test/test_utils.py | 19 ------ ui/packages/app/src/App.svelte | 2 + ui/packages/app/src/ExampleSet.svelte | 63 +++++++++++++++++-- ui/packages/app/src/Interface.svelte | 16 ++++- .../src/components/input/Number/Number.svelte | 4 +- 8 files changed, 104 insertions(+), 51 deletions(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index 2cf6ceb143..22ea0ba6be 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -98,8 +98,8 @@ class SimpleCSVLogger(FlaggingCallback): ) with open(log_filepath, "a", newline="") as csvfile: - writer = csv.writer(csvfile) - writer.writerow(utils.santize_for_csv(csv_data)) + writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC) + writer.writerow(csv_data) with open(log_filepath, "r") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 @@ -186,7 +186,7 @@ class CSVLogger(FlaggingCallback): content[flag_index][flag_col_index] = flag_option output = io.StringIO() writer = csv.writer(output) - writer.writerows(utils.santize_for_csv(content)) + writer.writerows(content) return output.getvalue() if interface.encrypt: @@ -200,7 +200,7 @@ class CSVLogger(FlaggingCallback): file_content = decrypted_csv.decode() if flag_index is not None: file_content = replace_flag_at_index(file_content) - output.write(utils.santize_for_csv(file_content)) + output.write(file_content) writer = csv.writer(output) if flag_index is None: if is_new: @@ -208,10 +208,8 @@ class CSVLogger(FlaggingCallback): writer.writerow(csv_data) with open(log_fp, "wb") as csvfile: csvfile.write( - utils.santize_for_csv( - encryptor.encrypt( - interface.encryption_key, output.getvalue().encode() - ) + encryptor.encrypt( + interface.encryption_key, output.getvalue().encode() ) ) else: @@ -219,8 +217,8 @@ class CSVLogger(FlaggingCallback): with open(log_fp, "a", newline="") as csvfile: writer = csv.writer(csvfile) if is_new: - writer.writerow(utils.santize_for_csv(headers)) - writer.writerow(utils.santize_for_csv(csv_data)) + writer.writerow(headers) + writer.writerow(csv_data) else: with open(log_fp) as csvfile: file_content = csvfile.read() @@ -228,7 +226,7 @@ class CSVLogger(FlaggingCallback): with open( log_fp, "w", newline="" ) as csvfile: # newline parameter needed for Windows - csvfile.write(utils.santize_for_csv(file_content)) + csvfile.write(file_content) with open(log_fp, "r") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 return line_count @@ -313,7 +311,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): infos = {"flagged": {"features": {}}} with open(self.log_file, "a", newline="") as csvfile: - writer = csv.writer(csvfile) + writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC) # File previews for certain input and output types file_preview_types = { @@ -370,7 +368,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): "_type": "Value", } - writer.writerow(utils.santize_for_csv(headers)) + writer.writerow(headers) # Generate the row corresponding to the flagged sample csv_data = [] @@ -405,7 +403,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): if flag_option is not None: csv_data.append(flag_option) - writer.writerow(utils.santize_for_csv(csv_data)) + writer.writerow(csv_data) if is_new: json.dump(infos, open(self.infos_file, "w")) diff --git a/gradio/interface.py b/gradio/interface.py index ab12ee21d6..2371067060 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -579,19 +579,24 @@ class Interface(Launchable): flag_option="" if self.flagging_options else None, username=username, ) - if self.stateful: - updated_state = prediction[self.state_return_index] - prediction[self.state_return_index] = None - else: - updated_state = None + if self.stateful: + updated_state = prediction[self.state_return_index] + prediction[self.state_return_index] = None + else: + updated_state = None - return { + durations= durations + avg_durations= self.config.get("avg_durations") + response = { "data": prediction, - "durations": durations, - "avg_durations": self.config.get("avg_durations"), "flag_index": flag_index, "updated_state": updated_state, } + if durations is not None: + response["durations"] = durations + if avg_durations is not None: + response["avg_durations"] = avg_durations + return response def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: """ diff --git a/gradio/utils.py b/gradio/utils.py index f7cc1d2a3b..70af9e2ad1 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -289,7 +289,7 @@ def get_default_args(func: Callable) -> Dict[str, Any]: ] -def santize_for_csv(data: str | List[str] | List[List[str]]): +def sanitize_for_csv(data: str | List[str] | List[List[str]]): """Sanitizes data so that it can be safely written to a CSV file.""" def sanitize(item): @@ -320,4 +320,4 @@ def santize_for_csv(data: str | List[str] | List[List[str]]): sanitized_data[outer_index][inner_index] = sanitize(item) return sanitized_data else: - raise ValueError("Unsupported data type: " + str(type(data))) + return data diff --git a/test/test_utils.py b/test/test_utils.py index eed645df62..4f17a68a14 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -15,7 +15,6 @@ from gradio.utils import ( json, launch_analytics, readme_to_html, - santize_for_csv, version_check, ) @@ -117,23 +116,5 @@ class TestIPAddress(unittest.TestCase): self.assertEqual(ip, "No internet connection") -class TestSanitizeForCSV(unittest.TestCase): - def test_safe(self): - safe_data = santize_for_csv("abc") - self.assertEquals(safe_data, "abc") - safe_data = santize_for_csv(["def"]) - self.assertEquals(safe_data, ["def"]) - safe_data = santize_for_csv([["abc"]]) - self.assertEquals(safe_data, [["abc"]]) - - def test_unsafe(self): - safe_data = santize_for_csv("=abc") - self.assertEquals(safe_data, "'=abc") - safe_data = santize_for_csv(["abc", "+abc"]) - self.assertEquals(safe_data, ["abc", "'+abc"]) - safe_data = santize_for_csv([["abc", "=abc"]]) - self.assertEquals(safe_data, [["abc", "'=abc"]]) - - if __name__ == "__main__": unittest.main() diff --git a/ui/packages/app/src/App.svelte b/ui/packages/app/src/App.svelte index d42cd1de61..1502ab43aa 100644 --- a/ui/packages/app/src/App.svelte +++ b/ui/packages/app/src/App.svelte @@ -30,6 +30,7 @@ export let input_components: Array; export let output_components: Array; export let examples: Array>; + export let examples_per_page: number; export let fn: (...args: any) => Promise; export let root: string; export let space: string | undefined = undefined; @@ -66,6 +67,7 @@ {input_components} {output_components} {examples} + {examples_per_page} {theme} {fn} {root} diff --git a/ui/packages/app/src/ExampleSet.svelte b/ui/packages/app/src/ExampleSet.svelte index 317a70c033..825added14 100644 --- a/ui/packages/app/src/ExampleSet.svelte +++ b/ui/packages/app/src/ExampleSet.svelte @@ -15,8 +15,38 @@ export let input_components: Array; export let theme: string; - let selected_examples = examples; + let page = 0; let gallery = input_components.length === 1; + let paginate = examples.length > examples_per_page; + + let selected_examples: Array>; + let page_count: number; + let visible_pages: Array = []; + $: { + if (paginate) { + visible_pages = []; + selected_examples = examples.slice( + page * examples_per_page, + (page + 1) * examples_per_page + ); + page_count = Math.ceil(examples.length / examples_per_page); + [0, page, page_count - 1].forEach((anchor) => { + for (let i = anchor - 2; i <= anchor + 2; i++) { + if (i >= 0 && i < page_count && !visible_pages.includes(i)) { + if ( + visible_pages.length > 0 && + i - visible_pages[visible_pages.length - 1] > 1 + ) { + visible_pages.push(-1); + } + visible_pages.push(i); + } + } + }); + } else { + selected_examples = examples.slice(); + } + }
@@ -31,7 +61,8 @@ {#each selected_examples as example_row, i}
+ {#if paginate} +
+ Pages: + {#each visible_pages as visible_page} + {#if visible_page === -1} +
...
+ {:else} + + {/if} + {/each} +
+ {/if} diff --git a/ui/packages/app/src/Interface.svelte b/ui/packages/app/src/Interface.svelte index 3d1314bd2a..fb7024678d 100644 --- a/ui/packages/app/src/Interface.svelte +++ b/ui/packages/app/src/Interface.svelte @@ -17,6 +17,7 @@ export let theme: string; export let fn: (...args: any) => Promise; export let examples: Array>; + export let examples_per_page: number; export let root: string; export let allow_flagging: string; export let flagging_options: Array | undefined = undefined; @@ -49,8 +50,10 @@ let timer_diff = 0; let avg_duration = Array.isArray(avg_durations) ? avg_durations[0] : null; let expected_duration: number | null = null; + let example_id: number | null = null; const setValues = (index: number, value: unknown) => { + example_id = null; has_changed = true; input_values[index] = value; if (live && state !== "PENDING") { @@ -58,7 +61,8 @@ } }; - const setExampleId = async (example_id: number) => { + const setExampleId = async (_id: number) => { + example_id = _id; input_components.forEach(async (input_component, i) => { const process_example = input_component_map[input_component.name].process_example; @@ -71,6 +75,7 @@ input_values[i] = examples[example_id][i]; } }); + example_id = _id; }; const startTimer = () => { @@ -103,7 +108,12 @@ has_changed = false; let submission_count_at_click = submission_count; startTimer(); - fn("predict", { data: input_values }, queue, queueCallback) + fn( + "predict", + { data: input_values, example_id: example_id }, + queue, + queueCallback + ) .then((output) => { if ( state !== "PENDING" || @@ -351,6 +361,8 @@ {#if examples} number; export let theme: string; - $: setValue(value); setValue(e.target.value)} {theme} /> From e88497e6e8055eb730369db0ba59e23076d5c654 Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Wed, 16 Mar 2022 19:31:20 -0500 Subject: [PATCH 2/8] changes --- gradio/utils.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/gradio/utils.py b/gradio/utils.py index 70af9e2ad1..217f0fc550 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -288,36 +288,3 @@ def get_default_args(func: Callable) -> Dict[str, Any]: for v in signature.parameters.values() ] - -def sanitize_for_csv(data: str | List[str] | List[List[str]]): - """Sanitizes data so that it can be safely written to a CSV file.""" - - def sanitize(item): - return "'" + item - - unsafe_prefixes = ("+", "=", "-", "@") - warning_message = "Sanitizing flagged data by escaping cell contents that begin " - "with one of the following characters: '+', '=', '-', '@'." - - if isinstance(data, str): - if data.startswith(unsafe_prefixes): - warnings.warn(warning_message) - return sanitize(data) - return data - elif isinstance(data, list) and isinstance(data[0], str): - sanitized_data = copy.deepcopy(data) - for index, item in enumerate(data): - if item.startswith(unsafe_prefixes): - warnings.warn(warning_message) - sanitized_data[index] = sanitize(item) - return sanitized_data - elif isinstance(data[0], list) and isinstance(data[0][0], str): - sanitized_data = copy.deepcopy(data) - for outer_index, sublist in enumerate(data): - for inner_index, item in enumerate(sublist): - if item.startswith(unsafe_prefixes): - warnings.warn(warning_message) - sanitized_data[outer_index][inner_index] = sanitize(item) - return sanitized_data - else: - return data From f1e65cba12395d310f974443c27c6094b57cac6b Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 17 Mar 2022 09:32:19 -0700 Subject: [PATCH 3/8] formatting --- ui/packages/app/src/components/input/Number/Number.svelte | 1 - 1 file changed, 1 deletion(-) diff --git a/ui/packages/app/src/components/input/Number/Number.svelte b/ui/packages/app/src/components/input/Number/Number.svelte index 4079fe0575..c4ebf04e15 100644 --- a/ui/packages/app/src/components/input/Number/Number.svelte +++ b/ui/packages/app/src/components/input/Number/Number.svelte @@ -2,7 +2,6 @@ export let value: number; export let setValue: (val: number) => number; export let theme: string; - Date: Thu, 17 Mar 2022 09:38:54 -0700 Subject: [PATCH 4/8] backend formatting --- gradio/interface.py | 4 ++-- gradio/utils.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 2371067060..c805be5a16 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -585,8 +585,8 @@ class Interface(Launchable): else: updated_state = None - durations= durations - avg_durations= self.config.get("avg_durations") + durations = durations + avg_durations = self.config.get("avg_durations") response = { "data": prediction, "flag_index": flag_index, diff --git a/gradio/utils.py b/gradio/utils.py index 217f0fc550..734c55e21c 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -287,4 +287,3 @@ def get_default_args(func: Callable) -> Dict[str, Any]: v.default if v.default is not inspect.Parameter.empty else None for v in signature.parameters.values() ] - From 16009b0ec00cc93f5d4d38116f537c93dbc961ee Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 17 Mar 2022 10:15:56 -0700 Subject: [PATCH 5/8] quoting fix --- gradio/flagging.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index 22ea0ba6be..0a1da0216c 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -98,7 +98,7 @@ class SimpleCSVLogger(FlaggingCallback): ) with open(log_filepath, "a", newline="") as csvfile: - writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC) + writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") writer.writerow(csv_data) with open(log_filepath, "r") as csvfile: @@ -185,7 +185,7 @@ class CSVLogger(FlaggingCallback): flag_col_index = header.index("flag") content[flag_index][flag_col_index] = flag_option output = io.StringIO() - writer = csv.writer(output) + writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") writer.writerows(content) return output.getvalue() @@ -201,7 +201,7 @@ class CSVLogger(FlaggingCallback): if flag_index is not None: file_content = replace_flag_at_index(file_content) output.write(file_content) - writer = csv.writer(output) + writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") if flag_index is None: if is_new: writer.writerow(headers) @@ -215,7 +215,7 @@ class CSVLogger(FlaggingCallback): else: if flag_index is None: with open(log_fp, "a", newline="") as csvfile: - writer = csv.writer(csvfile) + writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") if is_new: writer.writerow(headers) writer.writerow(csv_data) @@ -311,7 +311,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): infos = {"flagged": {"features": {}}} with open(self.log_file, "a", newline="") as csvfile: - writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC) + writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") # File previews for certain input and output types file_preview_types = { From f23554c2f5dcb9b182df57621b071e1e7d9a7033 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 17 Mar 2022 10:16:31 -0700 Subject: [PATCH 6/8] backend formatting --- gradio/flagging.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index 0a1da0216c..9443dd04ab 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -215,7 +215,9 @@ class CSVLogger(FlaggingCallback): else: if flag_index is None: with open(log_fp, "a", newline="") as csvfile: - writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") + writer = csv.writer( + csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'" + ) if is_new: writer.writerow(headers) writer.writerow(csv_data) From 7a43afa8eefab5cd82e4a7f5f27e438aec4569e2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 17 Mar 2022 10:30:51 -0700 Subject: [PATCH 7/8] removed quoting for hf dataset saver --- gradio/flagging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index 9443dd04ab..8ad52aa8e0 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -313,7 +313,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): infos = {"flagged": {"features": {}}} with open(self.log_file, "a", newline="") as csvfile: - writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'") + writer = csv.writer(csvfile) # File previews for certain input and output types file_preview_types = { From 0eefdaaf37029789c7a853bf019b82b8e7236faf Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 17 Mar 2022 10:39:50 -0700 Subject: [PATCH 8/8] fixed quoting in load --- gradio/process_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/process_examples.py b/gradio/process_examples.py index 6599c11870..4b1244935b 100644 --- a/gradio/process_examples.py +++ b/gradio/process_examples.py @@ -54,7 +54,7 @@ def cache_interface_examples(interface: Interface) -> None: def load_from_cache(interface: Interface, example_id: int) -> List[Any]: """Loads a particular cached example for the interface.""" with open(CACHE_FILE) as cache: - examples = list(csv.reader(cache)) + examples = list(csv.reader(cache, quotechar="'")) example = examples[example_id + 1] # +1 to adjust for header output = [] for component, cell in zip(interface.output_components, example):