mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-25 12:10:31 +08:00
Merge pull request #827 from gradio-app/examples_bug_fixes
Fixes to examples; use builtin csv sanitization
This commit is contained in:
commit
087de3f067
gradio
test
ui/packages/app/src
@ -98,8 +98,8 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
)
|
||||
|
||||
with open(log_filepath, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
|
||||
writer.writerow(csv_data)
|
||||
|
||||
with open(log_filepath, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
@ -185,8 +185,8 @@ class CSVLogger(FlaggingCallback):
|
||||
flag_col_index = header.index("flag")
|
||||
content[flag_index][flag_col_index] = flag_option
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerows(utils.santize_for_csv(content))
|
||||
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
|
||||
writer.writerows(content)
|
||||
return output.getvalue()
|
||||
|
||||
if interface.encrypt:
|
||||
@ -200,27 +200,27 @@ class CSVLogger(FlaggingCallback):
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
output.write(utils.santize_for_csv(file_content))
|
||||
writer = csv.writer(output)
|
||||
output.write(file_content)
|
||||
writer = csv.writer(output, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
|
||||
if flag_index is None:
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(
|
||||
utils.santize_for_csv(
|
||||
encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()
|
||||
)
|
||||
encryptor.encrypt(
|
||||
interface.encryption_key, output.getvalue().encode()
|
||||
)
|
||||
)
|
||||
else:
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer = csv.writer(
|
||||
csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'"
|
||||
)
|
||||
if is_new:
|
||||
writer.writerow(utils.santize_for_csv(headers))
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
else:
|
||||
with open(log_fp) as csvfile:
|
||||
file_content = csvfile.read()
|
||||
@ -228,7 +228,7 @@ class CSVLogger(FlaggingCallback):
|
||||
with open(
|
||||
log_fp, "w", newline=""
|
||||
) as csvfile: # newline parameter needed for Windows
|
||||
csvfile.write(utils.santize_for_csv(file_content))
|
||||
csvfile.write(file_content)
|
||||
with open(log_fp, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
@ -370,7 +370,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"_type": "Value",
|
||||
}
|
||||
|
||||
writer.writerow(utils.santize_for_csv(headers))
|
||||
writer.writerow(headers)
|
||||
|
||||
# Generate the row corresponding to the flagged sample
|
||||
csv_data = []
|
||||
@ -405,7 +405,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
writer.writerow(utils.santize_for_csv(csv_data))
|
||||
writer.writerow(csv_data)
|
||||
|
||||
if is_new:
|
||||
json.dump(infos, open(self.infos_file, "w"))
|
||||
|
@ -579,19 +579,24 @@ class Interface(Launchable):
|
||||
flag_option="" if self.flagging_options else None,
|
||||
username=username,
|
||||
)
|
||||
if self.stateful:
|
||||
updated_state = prediction[self.state_return_index]
|
||||
prediction[self.state_return_index] = None
|
||||
else:
|
||||
updated_state = None
|
||||
if self.stateful:
|
||||
updated_state = prediction[self.state_return_index]
|
||||
prediction[self.state_return_index] = None
|
||||
else:
|
||||
updated_state = None
|
||||
|
||||
return {
|
||||
durations = durations
|
||||
avg_durations = self.config.get("avg_durations")
|
||||
response = {
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"avg_durations": self.config.get("avg_durations"),
|
||||
"flag_index": flag_index,
|
||||
"updated_state": updated_state,
|
||||
}
|
||||
if durations is not None:
|
||||
response["durations"] = durations
|
||||
if avg_durations is not None:
|
||||
response["avg_durations"] = avg_durations
|
||||
return response
|
||||
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
|
@ -54,7 +54,7 @@ def cache_interface_examples(interface: Interface) -> None:
|
||||
def load_from_cache(interface: Interface, example_id: int) -> List[Any]:
|
||||
"""Loads a particular cached example for the interface."""
|
||||
with open(CACHE_FILE) as cache:
|
||||
examples = list(csv.reader(cache))
|
||||
examples = list(csv.reader(cache, quotechar="'"))
|
||||
example = examples[example_id + 1] # +1 to adjust for header
|
||||
output = []
|
||||
for component, cell in zip(interface.output_components, example):
|
||||
|
@ -287,37 +287,3 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
for v in signature.parameters.values()
|
||||
]
|
||||
|
||||
|
||||
def santize_for_csv(data: str | List[str] | List[List[str]]):
|
||||
"""Sanitizes data so that it can be safely written to a CSV file."""
|
||||
|
||||
def sanitize(item):
|
||||
return "'" + item
|
||||
|
||||
unsafe_prefixes = ("+", "=", "-", "@")
|
||||
warning_message = "Sanitizing flagged data by escaping cell contents that begin "
|
||||
"with one of the following characters: '+', '=', '-', '@'."
|
||||
|
||||
if isinstance(data, str):
|
||||
if data.startswith(unsafe_prefixes):
|
||||
warnings.warn(warning_message)
|
||||
return sanitize(data)
|
||||
return data
|
||||
elif isinstance(data, list) and isinstance(data[0], str):
|
||||
sanitized_data = copy.deepcopy(data)
|
||||
for index, item in enumerate(data):
|
||||
if item.startswith(unsafe_prefixes):
|
||||
warnings.warn(warning_message)
|
||||
sanitized_data[index] = sanitize(item)
|
||||
return sanitized_data
|
||||
elif isinstance(data[0], list) and isinstance(data[0][0], str):
|
||||
sanitized_data = copy.deepcopy(data)
|
||||
for outer_index, sublist in enumerate(data):
|
||||
for inner_index, item in enumerate(sublist):
|
||||
if item.startswith(unsafe_prefixes):
|
||||
warnings.warn(warning_message)
|
||||
sanitized_data[outer_index][inner_index] = sanitize(item)
|
||||
return sanitized_data
|
||||
else:
|
||||
raise ValueError("Unsupported data type: " + str(type(data)))
|
||||
|
@ -15,7 +15,6 @@ from gradio.utils import (
|
||||
json,
|
||||
launch_analytics,
|
||||
readme_to_html,
|
||||
santize_for_csv,
|
||||
version_check,
|
||||
)
|
||||
|
||||
@ -117,23 +116,5 @@ class TestIPAddress(unittest.TestCase):
|
||||
self.assertEqual(ip, "No internet connection")
|
||||
|
||||
|
||||
class TestSanitizeForCSV(unittest.TestCase):
|
||||
def test_safe(self):
|
||||
safe_data = santize_for_csv("abc")
|
||||
self.assertEquals(safe_data, "abc")
|
||||
safe_data = santize_for_csv(["def"])
|
||||
self.assertEquals(safe_data, ["def"])
|
||||
safe_data = santize_for_csv([["abc"]])
|
||||
self.assertEquals(safe_data, [["abc"]])
|
||||
|
||||
def test_unsafe(self):
|
||||
safe_data = santize_for_csv("=abc")
|
||||
self.assertEquals(safe_data, "'=abc")
|
||||
safe_data = santize_for_csv(["abc", "+abc"])
|
||||
self.assertEquals(safe_data, ["abc", "'+abc"])
|
||||
safe_data = santize_for_csv([["abc", "=abc"]])
|
||||
self.assertEquals(safe_data, [["abc", "'=abc"]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -30,6 +30,7 @@
|
||||
export let input_components: Array<Component>;
|
||||
export let output_components: Array<Component>;
|
||||
export let examples: Array<Array<unknown>>;
|
||||
export let examples_per_page: number;
|
||||
export let fn: (...args: any) => Promise<unknown>;
|
||||
export let root: string;
|
||||
export let space: string | undefined = undefined;
|
||||
@ -66,6 +67,7 @@
|
||||
{input_components}
|
||||
{output_components}
|
||||
{examples}
|
||||
{examples_per_page}
|
||||
{theme}
|
||||
{fn}
|
||||
{root}
|
||||
|
@ -15,8 +15,38 @@
|
||||
export let input_components: Array<Component>;
|
||||
export let theme: string;
|
||||
|
||||
let selected_examples = examples;
|
||||
let page = 0;
|
||||
let gallery = input_components.length === 1;
|
||||
let paginate = examples.length > examples_per_page;
|
||||
|
||||
let selected_examples: Array<Array<unknown>>;
|
||||
let page_count: number;
|
||||
let visible_pages: Array<number> = [];
|
||||
$: {
|
||||
if (paginate) {
|
||||
visible_pages = [];
|
||||
selected_examples = examples.slice(
|
||||
page * examples_per_page,
|
||||
(page + 1) * examples_per_page
|
||||
);
|
||||
page_count = Math.ceil(examples.length / examples_per_page);
|
||||
[0, page, page_count - 1].forEach((anchor) => {
|
||||
for (let i = anchor - 2; i <= anchor + 2; i++) {
|
||||
if (i >= 0 && i < page_count && !visible_pages.includes(i)) {
|
||||
if (
|
||||
visible_pages.length > 0 &&
|
||||
i - visible_pages[visible_pages.length - 1] > 1
|
||||
) {
|
||||
visible_pages.push(-1);
|
||||
}
|
||||
visible_pages.push(i);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
selected_examples = examples.slice();
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="examples" {theme}>
|
||||
@ -31,7 +61,8 @@
|
||||
{#each selected_examples as example_row, i}
|
||||
<button
|
||||
class="example cursor-pointer p-2 rounded bg-gray-50 dark:bg-gray-700 transition"
|
||||
on:click={() => setExampleId(i)}
|
||||
class:selected={i + page * examples_per_page === example_id}
|
||||
on:click={() => setExampleId(i + page * examples_per_page)}
|
||||
>
|
||||
<svelte:component
|
||||
this={input_component_map[input_components[0].name].example}
|
||||
@ -59,8 +90,8 @@
|
||||
{#each selected_examples as example_row, i}
|
||||
<tr
|
||||
class="cursor-pointer transition"
|
||||
class:selected={i === example_id}
|
||||
on:click={() => setExampleId(i)}
|
||||
class:selected={i + page * examples_per_page === example_id}
|
||||
on:click={() => setExampleId(i + page * examples_per_page)}
|
||||
>
|
||||
{#each example_row as example_cell, j}
|
||||
<td class="py-2 px-4">
|
||||
@ -78,6 +109,24 @@
|
||||
</table>
|
||||
{/if}
|
||||
</div>
|
||||
{#if paginate}
|
||||
<div class="flex gap-2 items-center mt-4">
|
||||
Pages:
|
||||
{#each visible_pages as visible_page}
|
||||
{#if visible_page === -1}
|
||||
<div>...</div>
|
||||
{:else}
|
||||
<button
|
||||
class="page"
|
||||
class:font-bold={page === visible_page}
|
||||
on:click={() => (page = visible_page)}
|
||||
>
|
||||
{visible_page + 1}
|
||||
</button>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<style lang="postcss" global>
|
||||
@ -102,5 +151,11 @@
|
||||
@apply bg-amber-500 text-white;
|
||||
}
|
||||
}
|
||||
.examples-table tr.selected {
|
||||
@apply font-semibold;
|
||||
}
|
||||
.page {
|
||||
@apply py-1 px-2 bg-gray-100 dark:bg-gray-700 rounded;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
@ -17,6 +17,7 @@
|
||||
export let theme: string;
|
||||
export let fn: (...args: any) => Promise<unknown>;
|
||||
export let examples: Array<Array<unknown>>;
|
||||
export let examples_per_page: number;
|
||||
export let root: string;
|
||||
export let allow_flagging: string;
|
||||
export let flagging_options: Array<string> | undefined = undefined;
|
||||
@ -49,8 +50,10 @@
|
||||
let timer_diff = 0;
|
||||
let avg_duration = Array.isArray(avg_durations) ? avg_durations[0] : null;
|
||||
let expected_duration: number | null = null;
|
||||
let example_id: number | null = null;
|
||||
|
||||
const setValues = (index: number, value: unknown) => {
|
||||
example_id = null;
|
||||
has_changed = true;
|
||||
input_values[index] = value;
|
||||
if (live && state !== "PENDING") {
|
||||
@ -58,7 +61,8 @@
|
||||
}
|
||||
};
|
||||
|
||||
const setExampleId = async (example_id: number) => {
|
||||
const setExampleId = async (_id: number) => {
|
||||
example_id = _id;
|
||||
input_components.forEach(async (input_component, i) => {
|
||||
const process_example =
|
||||
input_component_map[input_component.name].process_example;
|
||||
@ -71,6 +75,7 @@
|
||||
input_values[i] = examples[example_id][i];
|
||||
}
|
||||
});
|
||||
example_id = _id;
|
||||
};
|
||||
|
||||
const startTimer = () => {
|
||||
@ -103,7 +108,12 @@
|
||||
has_changed = false;
|
||||
let submission_count_at_click = submission_count;
|
||||
startTimer();
|
||||
fn("predict", { data: input_values }, queue, queueCallback)
|
||||
fn(
|
||||
"predict",
|
||||
{ data: input_values, example_id: example_id },
|
||||
queue,
|
||||
queueCallback
|
||||
)
|
||||
.then((output) => {
|
||||
if (
|
||||
state !== "PENDING" ||
|
||||
@ -351,6 +361,8 @@
|
||||
{#if examples}
|
||||
<ExampleSet
|
||||
{examples}
|
||||
{examples_per_page}
|
||||
{example_id}
|
||||
{input_components}
|
||||
{theme}
|
||||
{examples_dir}
|
||||
|
@ -2,14 +2,13 @@
|
||||
export let value: number;
|
||||
export let setValue: (val: number) => number;
|
||||
export let theme: string;
|
||||
|
||||
$: setValue(value);
|
||||
</script>
|
||||
|
||||
<input
|
||||
type="number"
|
||||
class="input-number w-full rounded box-border p-2 focus:outline-none appearance-none"
|
||||
bind:value
|
||||
{value}
|
||||
on:input={(e) => setValue(e.target.value)}
|
||||
{theme}
|
||||
/>
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user