mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-03 01:50:59 +08:00
597337dcb8
* added playground with 12 demos * change name to recipes, restyle navbar * add explanatory text to page * fix demo mapping * categorize demos, clean up design * styling * category naming and emojis * refactor and add text demos * add view code button * remove opening slash in embed * styling * add image demos * adding plot demos * remove see code button * removed submodules * changes * add audio models * remove fun section * remove tests in image segmentation demo repo * requested changes * add outbreak_forecast * fix broken demos * remove images and models, add new demos * remove readmes, change to run.py, add description as comment * move to /demos folder, clean up dict * add upload_to_spaces script * fix script, clean repos, and add to docker file * fix python versioning issue * env variable * fix * env fixes * spaces instead of tabs * revert to original networking.py * fix rate limiting in asr and autocomplete * change name to demos * clean up navbar * move url and description, remove code comments * add tabs to demos * remove margins and footer from embedded demo * font consistency Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
import gradio as gr
|
|
import random
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
import shap
|
|
import xgboost as xgb
|
|
from datasets import load_dataset
|
|
|
|
|
|
# Use the non-interactive Agg backend so figures can be rendered without a display.
matplotlib.use("Agg")

# Load the census-income dataset and turn its train split into a feature frame.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas().drop(columns=["fnlwgt", "race"])
# Binary target taken out of the frame: 1 when income is ">50K", otherwise 0.
y_train = (X_train.pop("income") == ">50K").astype(int)

# String-valued columns that XGBoost should treat as categorical features.
categorical_columns = [
    "workclass", "education", "marital.status", "occupation",
    "relationship", "sex", "native.country",
]
X_train = X_train.astype(dict.fromkeys(categorical_columns, "category"))

# Train a binary logistic booster on the native categorical support, then wrap
# it in a SHAP tree explainer used by the "Explain" button.
train_params = {"objective": "binary:logistic"}
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train(params=train_params, dtrain=data)
explainer = shap.TreeExplainer(model)
|
|
|
|
def predict(*args):
    """Return the model's income-class probabilities for a single applicant.

    Args:
        *args: One value per feature, positionally matching ``X_train.columns``.

    Returns:
        A dict mapping ``">50K"`` and ``"<=50K"`` to their probabilities
        (they sum to 1).
    """
    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the training-time categorical dtypes. A bare "category" cast on a
    # one-row frame creates a single-level category per column, so every value
    # would get code 0 and XGBoost would look up the wrong category. Values not
    # seen during training become NaN (treated as missing).
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
    # The booster is trained with "binary:logistic", so this is P(income > 50K).
    p_high = float(pos_pred[0])
    return {">50K": p_high, "<=50K": 1 - p_high}
|
|
|
|
|
|
def interpret(*args):
    """Plot per-feature SHAP values explaining the prediction for one applicant.

    Args:
        *args: One value per feature, positionally matching ``X_train.columns``.

    Returns:
        A matplotlib ``Figure`` with one horizontal bar per feature, ordered
        ascending by SHAP value.
    """
    # A standalone Figure is not registered with pyplot, so it is not leaked
    # into pyplot's open-figure list on every call and does not depend on
    # pyplot's global "current figure" (handlers may run concurrently).
    from matplotlib.figure import Figure

    df = pd.DataFrame([args], columns=X_train.columns)
    # Reuse the training-time categorical dtypes so category codes match the
    # encoding the model was trained on (a bare "category" cast would map
    # every value of a one-row frame to code 0).
    df = df.astype({col: X_train[col].dtype for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
    # A single input row, so shap_values[0] holds that row's per-feature values.
    scores = sorted(zip(shap_values[0], X_train.columns))

    fig = Figure(tight_layout=True)
    ax = fig.add_subplot(111)
    ax.barh([name for _, name in scores], [value for value, _ in scores])
    ax.set_title("Feature Shap Values")
    # barh draws features along y and their SHAP values along x.
    ax.set_xlabel("Shap Value")
    ax.set_ylabel("Feature")
    return fig
|
|
|
|
|
|
def _sorted_levels(column):
    """Return the distinct values of ``X_train[column]``, sorted ascending."""
    return sorted(X_train[column].unique())


# Sorted choice lists for the categorical dropdowns.
unique_class = _sorted_levels("workclass")
unique_education = _sorted_levels("education")
unique_marital_status = _sorted_levels("marital.status")
unique_relationship = _sorted_levels("relationship")
unique_occupation = _sorted_levels("occupation")
unique_sex = _sorted_levels("sex")
unique_country = _sorted_levels("native.country")
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""
**Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier that predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
""")
    with gr.Row():
        with gr.Column():
            # Every input starts from a random value (randomize=True on
            # sliders, a random-choice callable on dropdowns) so the form is
            # pre-filled with a random example applicant.
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked",
                minimum=1,
                maximum=99,
                step=1,
                randomize=True,  # randomized like every other input
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Explain")

    # predict() and interpret() build their DataFrame positionally against
    # X_train.columns, so this order must match the training columns. Defined
    # once so both events always receive the same inputs in the same order.
    feature_inputs = [
        age,
        work_class,
        education,
        years,
        marital_status,
        occupation,
        relationship,
        sex,
        capital_gain,
        capital_loss,
        hours_per_week,
        country,
    ]
    predict_btn.click(predict, inputs=feature_inputs, outputs=[label])
    interpret_btn.click(interpret, inputs=feature_inputs, outputs=[plot])

# Starts the Gradio server. NOTE(review): this runs on import as well as on
# direct execution; guard with `if __name__ == "__main__":` if the module must
# be importable without launching.
demo.launch()
|