mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
flag username, allow changing floag option on auto
This commit is contained in:
parent
4c2efc54a8
commit
2131caf642
@ -41,6 +41,7 @@ export class GradioInterface extends React.Component {
|
||||
state["just_flagged"] = false;
|
||||
state["has_changed"] = false;
|
||||
state["example_id"] = null;
|
||||
state["flag_index"] = null;
|
||||
return state;
|
||||
}
|
||||
clear() {
|
||||
@ -54,13 +55,19 @@ export class GradioInterface extends React.Component {
|
||||
}
|
||||
input_state[i] = this.state[i];
|
||||
}
|
||||
this.setState({ "submitting": true, "has_changed": false, "error": false });
|
||||
this.setState({ "submitting": true, "has_changed": false, "error": false, "flag_index": null });
|
||||
this.props.fn(input_state, "predict").then(output => {
|
||||
let index_start = this.props.input_components.length;
|
||||
let new_state = {};
|
||||
for (let [i, value] of output["data"].entries()) {
|
||||
this.setState({ [index_start + i]: value });
|
||||
new_state[index_start + i] = value;
|
||||
}
|
||||
this.setState({ "submitting": false, "complete": true });
|
||||
if (output["flag_index"] !== null) {
|
||||
new_state["flag_index"] = output["flag_index"];
|
||||
}
|
||||
new_state["submitting"] = false
|
||||
new_state["complete"] = true
|
||||
this.setState(new_state)
|
||||
if (this.props.live && this.state.has_changed) {
|
||||
this.submit();
|
||||
}
|
||||
@ -77,11 +84,15 @@ export class GradioInterface extends React.Component {
|
||||
return;
|
||||
}
|
||||
let component_state = { "input_data": [], "output_data": [] };
|
||||
for (let i = 0; i < this.props.input_components.length; i++) {
|
||||
component_state["input_data"].push(this.state[i]);
|
||||
}
|
||||
for (let i = 0; i < this.props.output_components.length; i++) {
|
||||
component_state["output_data"].push(this.state[this.props.input_components.length + i]);
|
||||
if (this.state.flag_index !== undefined) {
|
||||
component_state["flag_index"] = this.state.flag_index;
|
||||
} else {
|
||||
for (let i = 0; i < this.props.input_components.length; i++) {
|
||||
component_state["input_data"].push(this.state[i]);
|
||||
}
|
||||
for (let i = 0; i < this.props.output_components.length; i++) {
|
||||
component_state["output_data"].push(this.state[this.props.input_components.length + i]);
|
||||
}
|
||||
}
|
||||
this.setState({ "just_flagged": true });
|
||||
window.setTimeout(() => {
|
||||
|
@ -179,8 +179,12 @@ def predict():
|
||||
output = {"data": prediction, "durations": durations}
|
||||
if app.interface.allow_flagging == "auto":
|
||||
try:
|
||||
flag_data(raw_input, prediction)
|
||||
except:
|
||||
flag_index = flag_data(raw_input, prediction,
|
||||
flag_option=(None if app.interface.flagging_options is None else ""),
|
||||
username=current_user.id if current_user.is_authenticated else None)
|
||||
output["flag_index"] = flag_index
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
pass
|
||||
return jsonify(output)
|
||||
|
||||
@ -274,29 +278,45 @@ def predict_examples():
|
||||
return jsonify(output)
|
||||
|
||||
|
||||
def flag_data(input_data, output_data, flag_option=None):
|
||||
def flag_data(input_data, output_data, flag_option=None, flag_index=None, username=None):
|
||||
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
|
||||
csv_data = []
|
||||
encryption_key = app.interface.encryption_key if app.interface.encrypt else None
|
||||
for i, interface in enumerate(app.interface.input_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
|
||||
for i, interface in enumerate(app.interface.output_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, app.interface.config["output_components"][i]["label"], output_data[i], encryption_key))
|
||||
if flag_option:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
headers = [interface["label"]
|
||||
for interface in app.interface.config["input_components"]]
|
||||
headers += [interface["label"]
|
||||
for interface in app.interface.config["output_components"]]
|
||||
if app.interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
|
||||
log_fp = "{}/log.csv".format(flag_path)
|
||||
encryption_key = app.interface.encryption_key if app.interface.encrypt else None
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
for i, interface in enumerate(app.interface.input_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
|
||||
for i, interface in enumerate(app.interface.output_components):
|
||||
csv_data.append(interface.save_flagged(
|
||||
flag_path, app.interface.config["output_components"][i]["label"], output_data[i], encryption_key))
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
if username is not None:
|
||||
csv_data.append(username)
|
||||
if is_new:
|
||||
headers = [interface["label"]
|
||||
for interface in app.interface.config["input_components"]]
|
||||
headers += [interface["label"]
|
||||
for interface in app.interface.config["output_components"]]
|
||||
if app.interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
if username is not None:
|
||||
headers.append("username")
|
||||
|
||||
def replace_flag_at_index(file_content):
|
||||
file_content = io.StringIO(file_content)
|
||||
content = list(csv.reader(file_content))
|
||||
header = content[0]
|
||||
flag_col_index = header.index("flag")
|
||||
content[flag_index][flag_col_index] = flag_option
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerows(content)
|
||||
return output.getvalue()
|
||||
|
||||
if app.interface.encrypt:
|
||||
output = io.StringIO()
|
||||
if not is_new:
|
||||
@ -304,28 +324,42 @@ def flag_data(input_data, output_data, flag_option=None):
|
||||
encrypted_csv = csvfile.read()
|
||||
decrypted_csv = encryptor.decrypt(
|
||||
app.interface.encryption_key, encrypted_csv)
|
||||
output.write(decrypted_csv.decode())
|
||||
file_content = decrypted_csv.decode()
|
||||
if flag_index is not None:
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
output.write(file_content)
|
||||
writer = csv.writer(output)
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
if flag_index is None:
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
csvfile.write(encryptor.encrypt(
|
||||
app.interface.encryption_key, output.getvalue().encode()))
|
||||
else:
|
||||
with open(log_fp, "a") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
else:
|
||||
with open(log_fp) as csvfile:
|
||||
file_content = csvfile.read()
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
with open(log_fp, "w") as csvfile:
|
||||
csvfile.write(file_content)
|
||||
with open(log_fp, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
|
||||
@app.route("/api/flag/", methods=["POST"])
|
||||
@login_check
|
||||
def flag():
|
||||
log_feature_analytics('flag')
|
||||
data = request.json['data']
|
||||
flag_data(data['input_data'], data['output_data'], data.get("flag_option"))
|
||||
flag_data(data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
|
||||
current_user.id if current_user.is_authenticated else None)
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user