2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-03-25 12:10:31 +08:00

Merge pull request from gradio-app/examples_bug_fixes

Fixes to examples; use builtin csv sanitization
This commit is contained in:
Abubakar Abid 2022-03-17 10:46:07 -07:00 committed by GitHub
commit 087de3f067
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 107 additions and 87 deletions

@ -98,8 +98,8 @@ class SimpleCSVLogger(FlaggingCallback):
)
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(utils.santize_for_csv(csv_data))
writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerow(csv_data)
with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
@ -185,8 +185,8 @@ class CSVLogger(FlaggingCallback):
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(utils.santize_for_csv(content))
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
writer.writerows(content)
return output.getvalue()
if interface.encrypt:
@ -200,27 +200,27 @@ class CSVLogger(FlaggingCallback):
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(utils.santize_for_csv(file_content))
writer = csv.writer(output)
output.write(file_content)
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
if flag_index is None:
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(
utils.santize_for_csv(
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
)
encryptor.encrypt(
interface.encryption_key, output.getvalue().encode()
)
)
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
writer = csv.writer(
csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
)
if is_new:
writer.writerow(utils.santize_for_csv(headers))
writer.writerow(utils.santize_for_csv(csv_data))
writer.writerow(headers)
writer.writerow(csv_data)
else:
with open(log_fp) as csvfile:
file_content = csvfile.read()
@ -228,7 +228,7 @@ class CSVLogger(FlaggingCallback):
with open(
log_fp, "w", newline=""
) as csvfile: # newline parameter needed for Windows
csvfile.write(utils.santize_for_csv(file_content))
csvfile.write(file_content)
with open(log_fp, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@ -370,7 +370,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"_type": "Value",
}
writer.writerow(utils.santize_for_csv(headers))
writer.writerow(headers)
# Generate the row corresponding to the flagged sample
csv_data = []
@ -405,7 +405,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
if flag_option is not None:
csv_data.append(flag_option)
writer.writerow(utils.santize_for_csv(csv_data))
writer.writerow(csv_data)
if is_new:
json.dump(infos, open(self.infos_file, "w"))

@ -579,19 +579,24 @@ class Interface(Launchable):
flag_option="" if self.flagging_options else None,
username=username,
)
if self.stateful:
updated_state = prediction[self.state_return_index]
prediction[self.state_return_index] = None
else:
updated_state = None
if self.stateful:
updated_state = prediction[self.state_return_index]
prediction[self.state_return_index] = None
else:
updated_state = None
return {
durations = durations
avg_durations = self.config.get("avg_durations")
response = {
"data": prediction,
"durations": durations,
"avg_durations": self.config.get("avg_durations"),
"flag_index": flag_index,
"updated_state": updated_state,
}
if durations is not None:
response["durations"] = durations
if avg_durations is not None:
response["avg_durations"] = avg_durations
return response
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""

@ -54,7 +54,7 @@ def cache_interface_examples(interface: Interface) -> None:
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
"""Loads a particular cached example for the interface."""
with open(CACHE_FILE) as cache:
examples = list(csv.reader(cache))
examples = list(csv.reader(cache, quotechar="'"))
example = examples[example_id + 1] # +1 to adjust for header
output = []
for component, cell in zip(interface.output_components, example):

@ -287,37 +287,3 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values()
]
def santize_for_csv(data: str | List[str] | List[List[str]]):
"""Sanitizes data so that it can be safely written to a CSV file."""
def sanitize(item):
return "'" + item
unsafe_prefixes = ("+", "=", "-", "@")
warning_message = "Sanitizing flagged data by escaping cell contents that begin "
"with one of the following characters: '+', '=', '-', '@'."
if isinstance(data, str):
if data.startswith(unsafe_prefixes):
warnings.warn(warning_message)
return sanitize(data)
return data
elif isinstance(data, list) and isinstance(data[0], str):
sanitized_data = copy.deepcopy(data)
for index, item in enumerate(data):
if item.startswith(unsafe_prefixes):
warnings.warn(warning_message)
sanitized_data[index] = sanitize(item)
return sanitized_data
elif isinstance(data[0], list) and isinstance(data[0][0], str):
sanitized_data = copy.deepcopy(data)
for outer_index, sublist in enumerate(data):
for inner_index, item in enumerate(sublist):
if item.startswith(unsafe_prefixes):
warnings.warn(warning_message)
sanitized_data[outer_index][inner_index] = sanitize(item)
return sanitized_data
else:
raise ValueError("Unsupported data type: " + str(type(data)))

@ -15,7 +15,6 @@ from gradio.utils import (
json,
launch_analytics,
readme_to_html,
santize_for_csv,
version_check,
)
@ -117,23 +116,5 @@ class TestIPAddress(unittest.TestCase):
self.assertEqual(ip, "No internet connection")
class TestSanitizeForCSV(unittest.TestCase):
    """Checks santize_for_csv on a single cell, a row, and a list of rows."""

    def test_safe(self):
        """Cells without an unsafe leading character come back unchanged."""
        # assertEqual replaces the deprecated assertEquals alias, which no
        # longer exists on unittest.TestCase in Python 3.12+.
        safe_data = santize_for_csv("abc")
        self.assertEqual(safe_data, "abc")
        safe_data = santize_for_csv(["def"])
        self.assertEqual(safe_data, ["def"])
        safe_data = santize_for_csv([["abc"]])
        self.assertEqual(safe_data, [["abc"]])

    def test_unsafe(self):
        """Cells starting with an unsafe character get a leading quote."""
        safe_data = santize_for_csv("=abc")
        self.assertEqual(safe_data, "'=abc")
        safe_data = santize_for_csv(["abc", "+abc"])
        self.assertEqual(safe_data, ["abc", "'+abc"])
        safe_data = santize_for_csv([["abc", "=abc"]])
        self.assertEqual(safe_data, [["abc", "'=abc"]])
if __name__ == "__main__":
unittest.main()

@ -30,6 +30,7 @@
export let input_components: Array<Component>;
export let output_components: Array<Component>;
export let examples: Array<Array<unknown>>;
export let examples_per_page: number;
export let fn: (...args: any) => Promise<unknown>;
export let root: string;
export let space: string | undefined = undefined;
@ -66,6 +67,7 @@
{input_components}
{output_components}
{examples}
{examples_per_page}
{theme}
{fn}
{root}

@ -15,8 +15,38 @@
export let input_components: Array<Component>;
export let theme: string;
let selected_examples = examples;
let page = 0;
let gallery = input_components.length === 1;
let paginate = examples.length > examples_per_page;
let selected_examples: Array<Array<unknown>>;
let page_count: number;
let visible_pages: Array<number> = [];
// Reactive statement: recomputes the examples shown for the current page
// and, when paginating, the list of page buttons to render.
$: {
if (paginate) {
// Rebuild the page-button list from scratch on every run.
visible_pages = [];
// Rows belonging to the current zero-based `page`.
selected_examples = examples.slice(
page * examples_per_page,
(page + 1) * examples_per_page
);
page_count = Math.ceil(examples.length / examples_per_page);
// Collect a window of up to two pages either side of three anchors:
// the first page, the current page and the last page.
[0, page, page_count - 1].forEach((anchor) => {
for (let i = anchor - 2; i <= anchor + 2; i++) {
// Skip out-of-range indices and pages already added by an earlier window.
if (i >= 0 && i < page_count && !visible_pages.includes(i)) {
if (
visible_pages.length > 0 &&
i - visible_pages[visible_pages.length - 1] > 1
) {
// -1 marks a gap between non-adjacent page numbers.
visible_pages.push(-1);
}
visible_pages.push(i);
}
}
});
} else {
// Not paginating: show a shallow copy of every example.
selected_examples = examples.slice();
}
}
</script>
<div class="examples" {theme}>
@ -31,7 +61,8 @@
{#each selected_examples as example_row, i}
<button
class="example cursor-pointer p-2 rounded bg-gray-50 dark:bg-gray-700 transition"
on:click={() => setExampleId(i)}
class:selected={i + page * examples_per_page === example_id}
on:click={() => setExampleId(i + page * examples_per_page)}
>
<svelte:component
this={input_component_map[input_components[0].name].example}
@ -59,8 +90,8 @@
{#each selected_examples as example_row, i}
<tr
class="cursor-pointer transition"
class:selected={i === example_id}
on:click={() => setExampleId(i)}
class:selected={i + page * examples_per_page === example_id}
on:click={() => setExampleId(i + page * examples_per_page)}
>
{#each example_row as example_cell, j}
<td class="py-2 px-4">
@ -78,6 +109,24 @@
</table>
{/if}
</div>
{#if paginate}
<div class="flex gap-2 items-center mt-4">
Pages:
{#each visible_pages as visible_page}
{#if visible_page === -1}
<div>...</div>
{:else}
<button
class="page"
class:font-bold={page === visible_page}
on:click={() => (page = visible_page)}
>
{visible_page + 1}
</button>
{/if}
{/each}
</div>
{/if}
</div>
<style lang="postcss" global>
@ -102,5 +151,11 @@
@apply bg-amber-500 text-white;
}
}
.examples-table tr.selected {
@apply font-semibold;
}
.page {
@apply py-1 px-2 bg-gray-100 dark:bg-gray-700 rounded;
}
}
</style>

@ -17,6 +17,7 @@
export let theme: string;
export let fn: (...args: any) => Promise<unknown>;
export let examples: Array<Array<unknown>>;
export let examples_per_page: number;
export let root: string;
export let allow_flagging: string;
export let flagging_options: Array<string> | undefined = undefined;
@ -49,8 +50,10 @@
let timer_diff = 0;
let avg_duration = Array.isArray(avg_durations) ? avg_durations[0] : null;
let expected_duration: number | null = null;
let example_id: number | null = null;
const setValues = (index: number, value: unknown) => {
example_id = null;
has_changed = true;
input_values[index] = value;
if (live && state !== "PENDING") {
@ -58,7 +61,8 @@
}
};
const setExampleId = async (example_id: number) => {
const setExampleId = async (_id: number) => {
example_id = _id;
input_components.forEach(async (input_component, i) => {
const process_example =
input_component_map[input_component.name].process_example;
@ -71,6 +75,7 @@
input_values[i] = examples[example_id][i];
}
});
example_id = _id;
};
const startTimer = () => {
@ -103,7 +108,12 @@
has_changed = false;
let submission_count_at_click = submission_count;
startTimer();
fn("predict", { data: input_values }, queue, queueCallback)
fn(
"predict",
{ data: input_values, example_id: example_id },
queue,
queueCallback
)
.then((output) => {
if (
state !== "PENDING" ||
@ -351,6 +361,8 @@
{#if examples}
<ExampleSet
{examples}
{examples_per_page}
{example_id}
{input_components}
{theme}
{examples_dir}

@ -2,14 +2,13 @@
export let value: number;
export let setValue: (val: number) => number;
export let theme: string;
$: setValue(value);
</script>
<input
type="number"
class="input-number w-full rounded box-border p-2 focus:outline-none appearance-none"
bind:value
{value}
on:input={(e) => setValue(e.target.value)}
{theme}
/>