flag username, allow changing floag option on auto

This commit is contained in:
Ali Abid 2021-07-06 15:07:14 -07:00
parent 4c2efc54a8
commit 2131caf642
2 changed files with 85 additions and 40 deletions

View File

@ -41,6 +41,7 @@ export class GradioInterface extends React.Component {
state["just_flagged"] = false;
state["has_changed"] = false;
state["example_id"] = null;
state["flag_index"] = null;
return state;
}
clear() {
@ -54,13 +55,19 @@ export class GradioInterface extends React.Component {
}
input_state[i] = this.state[i];
}
this.setState({ "submitting": true, "has_changed": false, "error": false });
this.setState({ "submitting": true, "has_changed": false, "error": false, "flag_index": null });
this.props.fn(input_state, "predict").then(output => {
let index_start = this.props.input_components.length;
let new_state = {};
for (let [i, value] of output["data"].entries()) {
this.setState({ [index_start + i]: value });
new_state[index_start + i] = value;
}
this.setState({ "submitting": false, "complete": true });
if (output["flag_index"] !== null) {
new_state["flag_index"] = output["flag_index"];
}
new_state["submitting"] = false
new_state["complete"] = true
this.setState(new_state)
if (this.props.live && this.state.has_changed) {
this.submit();
}
@ -77,11 +84,15 @@ export class GradioInterface extends React.Component {
return;
}
let component_state = { "input_data": [], "output_data": [] };
for (let i = 0; i < this.props.input_components.length; i++) {
component_state["input_data"].push(this.state[i]);
}
for (let i = 0; i < this.props.output_components.length; i++) {
component_state["output_data"].push(this.state[this.props.input_components.length + i]);
if (this.state.flag_index !== undefined) {
component_state["flag_index"] = this.state.flag_index;
} else {
for (let i = 0; i < this.props.input_components.length; i++) {
component_state["input_data"].push(this.state[i]);
}
for (let i = 0; i < this.props.output_components.length; i++) {
component_state["output_data"].push(this.state[this.props.input_components.length + i]);
}
}
this.setState({ "just_flagged": true });
window.setTimeout(() => {

View File

@ -179,8 +179,12 @@ def predict():
output = {"data": prediction, "durations": durations}
if app.interface.allow_flagging == "auto":
try:
flag_data(raw_input, prediction)
except:
flag_index = flag_data(raw_input, prediction,
flag_option=(None if app.interface.flagging_options is None else ""),
username=current_user.id if current_user.is_authenticated else None)
output["flag_index"] = flag_index
except Exception as e:
print(str(e))
pass
return jsonify(output)
@ -274,29 +278,45 @@ def predict_examples():
return jsonify(output)
def flag_data(input_data, output_data, flag_option=None):
def flag_data(input_data, output_data, flag_option=None, flag_index=None, username=None):
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
csv_data = []
encryption_key = app.interface.encryption_key if app.interface.encrypt else None
for i, interface in enumerate(app.interface.input_components):
csv_data.append(interface.save_flagged(
flag_path, app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
for i, interface in enumerate(app.interface.output_components):
csv_data.append(interface.save_flagged(
flag_path, app.interface.config["output_components"][i]["label"], output_data[i], encryption_key))
if flag_option:
csv_data.append(flag_option)
headers = [interface["label"]
for interface in app.interface.config["input_components"]]
headers += [interface["label"]
for interface in app.interface.config["output_components"]]
if app.interface.flagging_options is not None:
headers.append("flag")
log_fp = "{}/log.csv".format(flag_path)
encryption_key = app.interface.encryption_key if app.interface.encrypt else None
is_new = not os.path.exists(log_fp)
if flag_index is None:
csv_data = []
for i, interface in enumerate(app.interface.input_components):
csv_data.append(interface.save_flagged(
flag_path, app.interface.config["input_components"][i]["label"], input_data[i], encryption_key))
for i, interface in enumerate(app.interface.output_components):
csv_data.append(interface.save_flagged(
flag_path, app.interface.config["output_components"][i]["label"], output_data[i], encryption_key))
if flag_option is not None:
csv_data.append(flag_option)
if username is not None:
csv_data.append(username)
if is_new:
headers = [interface["label"]
for interface in app.interface.config["input_components"]]
headers += [interface["label"]
for interface in app.interface.config["output_components"]]
if app.interface.flagging_options is not None:
headers.append("flag")
if username is not None:
headers.append("username")
def replace_flag_at_index(file_content):
file_content = io.StringIO(file_content)
content = list(csv.reader(file_content))
header = content[0]
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(content)
return output.getvalue()
if app.interface.encrypt:
output = io.StringIO()
if not is_new:
@ -304,28 +324,42 @@ def flag_data(input_data, output_data, flag_option=None):
encrypted_csv = csvfile.read()
decrypted_csv = encryptor.decrypt(
app.interface.encryption_key, encrypted_csv)
output.write(decrypted_csv.decode())
file_content = decrypted_csv.decode()
if flag_index is not None:
file_content = replace_flag_at_index(file_content)
output.write(file_content)
writer = csv.writer(output)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
if flag_index is None:
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
csvfile.write(encryptor.encrypt(
app.interface.encryption_key, output.getvalue().encode()))
else:
with open(log_fp, "a") as csvfile:
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
if flag_index is None:
with open(log_fp, "a") as csvfile:
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
else:
with open(log_fp) as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
with open(log_fp, "w") as csvfile:
csvfile.write(file_content)
with open(log_fp, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count
@app.route("/api/flag/", methods=["POST"])
@login_check
def flag():
log_feature_analytics('flag')
data = request.json['data']
flag_data(data['input_data'], data['output_data'], data.get("flag_option"))
flag_data(data['input_data'], data['output_data'], data.get("flag_option"), data.get("flag_index"),
current_user.id if current_user.is_authenticated else None)
return jsonify(success=True)