fixed label bugs

This commit is contained in:
Abubakar Abid 2022-03-16 16:09:20 -07:00
parent 10e1f7c8ff
commit b87de1228a
2 changed files with 18 additions and 13 deletions

View File

@ -330,7 +330,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
for i, component in enumerate(interface.input_components): for i, component in enumerate(interface.input_components):
component_label = interface.config["input_components"][i][ component_label = interface.config["input_components"][i][
"label" "label"
] or "Input_{}".format(i) ] or "input_{}".format(i)
headers.append(component_label) headers.append(component_label)
infos["flagged"]["features"][component_label] = { infos["flagged"]["features"][component_label] = {
"dtype": "string", "dtype": "string",
@ -348,7 +348,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
for i, component in enumerate(interface.output_components): for i, component in enumerate(interface.output_components):
component_label = interface.config["output_components"][i][ component_label = interface.config["output_components"][i][
"label" "label"
] or "Output_{}".format(i) ] or "output_{}".format(i)
headers.append(component_label) headers.append(component_label)
infos["flagged"]["features"][component_label] = { infos["flagged"]["features"][component_label] = {
"dtype": "string", "dtype": "string",
@ -377,7 +377,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
for i, component in enumerate(interface.input_components): for i, component in enumerate(interface.input_components):
label = interface.config["input_components"][i][ label = interface.config["input_components"][i][
"label" "label"
] or "Input_{}".format(i) ] or "input_{}".format(i)
filepath = component.save_flagged( filepath = component.save_flagged(
self.dataset_dir, label, input_data[i], None self.dataset_dir, label, input_data[i], None
) )
@ -389,7 +389,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
for i, component in enumerate(interface.output_components): for i, component in enumerate(interface.output_components):
label = interface.config["output_components"][i][ label = interface.config["output_components"][i][
"label" "label"
] or "Output_{}".format(i) ] or "output_{}".format(i)
filepath = ( filepath = (
component.save_flagged( component.save_flagged(
self.dataset_dir, label, output_data[i], None self.dataset_dir, label, output_data[i], None

View File

@ -215,29 +215,34 @@ def get_config_file(interface: Interface) -> Dict[str, Any]:
} }
try: try:
param_names = inspect.getfullargspec(interface.predict[0])[0] param_names = inspect.getfullargspec(interface.predict[0])[0]
for iface, param in zip(config["input_components"], param_names): for i, iface in enumerate(config["input_components"]):
if not iface["label"]: if not iface["label"]:
iface["label"] = param.replace("_", " ") if i < len(param_names):
iface["label"] = param_names[i].replace("_", " ")
else:
iface["label"] = (
f"input {i + 1}"
if len(config["input_components"]) > 1
else "input"
)
for i, iface in enumerate(config["output_components"]): for i, iface in enumerate(config["output_components"]):
outputs_per_function = int( outputs_per_function = int(
len(interface.output_components) / len(interface.predict) len(interface.output_components) / len(interface.predict)
) )
function_index = i // outputs_per_function function_index = i // outputs_per_function
component_index = i - function_index * outputs_per_function component_index = i - function_index * outputs_per_function
ret_name = (
"Output " + str(component_index + 1)
if outputs_per_function > 1
else "Output"
)
if iface["label"] is None: if iface["label"] is None:
iface["label"] = ret_name iface["label"] = (
f"output {component_index + 1}"
if outputs_per_function > 1
else "output"
)
if len(interface.predict) > 1: if len(interface.predict) > 1:
iface["label"] = ( iface["label"] = (
interface.function_names[function_index].replace("_", " ") interface.function_names[function_index].replace("_", " ")
+ ": " + ": "
+ iface["label"] + iface["label"]
) )
except ValueError: except ValueError:
pass pass
if interface.examples is not None: if interface.examples is not None: