mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Refactor plots to drop altair
and use vega.js
directly (#8807)
* changes * add changeset * changes * changes * changes * add changeset * changes * add changeset * changes * add changeset * add changeset * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * add changeset * changes * changes * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Update gradio/blocks.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * changes * changes * changes * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * changes * changes * changes --------- Co-authored-by: Ali Abid <aliabid94@gmail.com> Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
914b1935de
commit
a238af4d68
8
.changeset/tangy-beds-guess.md
Normal file
8
.changeset/tangy-beds-guess.md
Normal file
@ -0,0 +1,8 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"@gradio/datetime": minor
|
||||
"@gradio/nativeplot": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Refactor plots to drop `altair` and use `vega.js` directly
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks(fill_width=True) as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -10,7 +10,7 @@ def xray_model(diseases, img):
|
||||
def ct_model(diseases, img):
|
||||
return [{disease: 0.1 for disease in diseases}]
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Blocks(fill_width=True) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Detect Disease From Scan
|
||||
|
@ -1,111 +1,77 @@
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from data import temp_sensor_data, food_rating_data
|
||||
|
||||
from vega_datasets import data
|
||||
with gr.Blocks() as bar_plots:
|
||||
with gr.Row():
|
||||
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
|
||||
end = gr.DateTime("2021-01-05 00:00:00", label="End")
|
||||
apply_btn = gr.Button("Apply", scale=0)
|
||||
with gr.Row():
|
||||
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
|
||||
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
|
||||
|
||||
barley = data.barley()
|
||||
simple = pd.DataFrame({
|
||||
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
|
||||
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
|
||||
})
|
||||
|
||||
def bar_plot_fn(display):
|
||||
if display == "simple":
|
||||
return gr.BarPlot(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
color=None,
|
||||
group=None,
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=['a', 'b'],
|
||||
y_lim=[20, 100],
|
||||
x_title=None,
|
||||
y_title=None,
|
||||
vertical=True,
|
||||
)
|
||||
elif display == "stacked":
|
||||
return gr.BarPlot(
|
||||
barley,
|
||||
x="variety",
|
||||
y="yield",
|
||||
color="site",
|
||||
group=None,
|
||||
title="Barley Yield Data",
|
||||
tooltip=['variety', 'site'],
|
||||
y_lim=None,
|
||||
x_title=None,
|
||||
y_title=None,
|
||||
vertical=True,
|
||||
)
|
||||
elif display == "grouped":
|
||||
return gr.BarPlot(
|
||||
barley.astype({"year": str}),
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
tooltip=["yield", "site", "year"],
|
||||
y_lim=None,
|
||||
x_title=None,
|
||||
y_title=None,
|
||||
vertical=True,
|
||||
)
|
||||
elif display == "simple-horizontal":
|
||||
return gr.BarPlot(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
color=None,
|
||||
group=None,
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=['a', 'b'],
|
||||
y_lim=[20, 100],
|
||||
x_title="Variable A",
|
||||
y_title="Variable B",
|
||||
vertical=False,
|
||||
)
|
||||
elif display == "stacked-horizontal":
|
||||
return gr.BarPlot(
|
||||
barley,
|
||||
x="variety",
|
||||
y="yield",
|
||||
color="site",
|
||||
group=None,
|
||||
title="Barley Yield Data",
|
||||
tooltip=['variety', 'site'],
|
||||
y_lim=None,
|
||||
x_title=None,
|
||||
y_title=None,
|
||||
vertical=False,
|
||||
)
|
||||
elif display == "grouped-horizontal":
|
||||
return gr.BarPlot(
|
||||
barley.astype({"year": str}),
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
group_title="",
|
||||
tooltip=["yield", "site", "year"],
|
||||
y_lim=None,
|
||||
x_title=None,
|
||||
y_title=None,
|
||||
vertical=False
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as bar_plot:
|
||||
display = gr.Dropdown(
|
||||
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
|
||||
value="simple",
|
||||
label="Type of Bar Plot"
|
||||
temp_by_time = gr.BarPlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
)
|
||||
plot = gr.BarPlot(show_label=False)
|
||||
display.change(bar_plot_fn, inputs=display, outputs=plot)
|
||||
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)
|
||||
temp_by_time_location = gr.BarPlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
color="location",
|
||||
)
|
||||
|
||||
time_graphs = [temp_by_time, temp_by_time_location]
|
||||
group_by.change(
|
||||
lambda group: [gr.BarPlot(x_bin=None if group == "None" else group)] * len(time_graphs),
|
||||
group_by,
|
||||
time_graphs
|
||||
)
|
||||
aggregate.change(
|
||||
lambda aggregate: [gr.BarPlot(y_aggregate=aggregate)] * len(time_graphs),
|
||||
aggregate,
|
||||
time_graphs
|
||||
)
|
||||
|
||||
|
||||
def rescale(select: gr.SelectData):
|
||||
return select.index
|
||||
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
|
||||
|
||||
for trigger in [apply_btn.click, rescale_evt.then]:
|
||||
trigger(
|
||||
lambda start, end: [gr.BarPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
price_by_cuisine = gr.BarPlot(
|
||||
food_rating_data,
|
||||
x="cuisine",
|
||||
y="price",
|
||||
)
|
||||
with gr.Column(scale=0):
|
||||
gr.Button("Sort $ > $$$").click(lambda: gr.BarPlot(sort="y"), None, price_by_cuisine)
|
||||
gr.Button("Sort $$$ > $").click(lambda: gr.BarPlot(sort="-y"), None, price_by_cuisine)
|
||||
gr.Button("Sort A > Z").click(lambda: gr.BarPlot(sort=["Chinese", "Italian", "Mexican"]), None, price_by_cuisine)
|
||||
|
||||
with gr.Row():
|
||||
price_by_rating = gr.BarPlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
x_bin=1,
|
||||
)
|
||||
price_by_rating_color = gr.BarPlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
color="cuisine",
|
||||
x_bin=1,
|
||||
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bar_plot.launch()
|
||||
bar_plots.launch()
|
||||
|
20
demo/native_plots/data.py
Normal file
20
demo/native_plots/data.py
Normal file
@ -0,0 +1,20 @@
|
||||
import pandas as pd
|
||||
from random import randint, choice, random
|
||||
|
||||
temp_sensor_data = pd.DataFrame(
|
||||
{
|
||||
"time": pd.date_range("2021-01-01", end="2021-01-05", periods=200),
|
||||
"temperature": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
|
||||
"humidity": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
|
||||
"location": ["indoor", "outdoor"] * 100,
|
||||
}
|
||||
)
|
||||
|
||||
food_rating_data = pd.DataFrame(
|
||||
{
|
||||
"cuisine": [["Italian", "Mexican", "Chinese"][i % 3] for i in range(100)],
|
||||
"rating": [random() * 4 + 0.5 * (i % 3) for i in range(100)],
|
||||
"price": [randint(10, 50) + 4 * (i % 3) for i in range(100)],
|
||||
"wait": [random() for i in range(100)],
|
||||
}
|
||||
)
|
@ -1,82 +1,69 @@
|
||||
import gradio as gr
|
||||
from vega_datasets import data
|
||||
import numpy as np
|
||||
from data import temp_sensor_data, food_rating_data
|
||||
|
||||
stocks = data.stocks()
|
||||
gapminder = data.gapminder()
|
||||
gapminder = gapminder.loc[
|
||||
gapminder.country.isin(["Argentina", "Australia", "Afghanistan"])
|
||||
]
|
||||
climate = data.climate()
|
||||
seattle_weather = data.seattle_weather()
|
||||
with gr.Blocks() as line_plots:
|
||||
with gr.Row():
|
||||
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
|
||||
end = gr.DateTime("2021-01-05 00:00:00", label="End")
|
||||
apply_btn = gr.Button("Apply", scale=0)
|
||||
with gr.Row():
|
||||
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
|
||||
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
|
||||
|
||||
|
||||
def line_plot_fn(dataset):
|
||||
if dataset == "stocks":
|
||||
return gr.LinePlot(
|
||||
stocks,
|
||||
x="date",
|
||||
y="price",
|
||||
color="symbol",
|
||||
x_lim=None,
|
||||
y_lim=None,
|
||||
stroke_dash=None,
|
||||
tooltip=['date', 'price', 'symbol'],
|
||||
overlay_point=False,
|
||||
title="Stock Prices",
|
||||
stroke_dash_legend_title=None,
|
||||
)
|
||||
elif dataset == "climate":
|
||||
return gr.LinePlot(
|
||||
climate,
|
||||
x="DATE",
|
||||
y="HLY-TEMP-NORMAL",
|
||||
color=None,
|
||||
x_lim=None,
|
||||
y_lim=[250, 500],
|
||||
stroke_dash=None,
|
||||
tooltip=['DATE', 'HLY-TEMP-NORMAL'],
|
||||
overlay_point=False,
|
||||
title="Climate",
|
||||
stroke_dash_legend_title=None,
|
||||
)
|
||||
elif dataset == "seattle_weather":
|
||||
return gr.LinePlot(
|
||||
seattle_weather,
|
||||
x="date",
|
||||
y="temp_min",
|
||||
color=None,
|
||||
x_lim=None,
|
||||
y_lim=None,
|
||||
stroke_dash=None,
|
||||
tooltip=["weather", "date"],
|
||||
overlay_point=True,
|
||||
title="Seattle Weather",
|
||||
stroke_dash_legend_title=None,
|
||||
)
|
||||
elif dataset == "gapminder":
|
||||
return gr.LinePlot(
|
||||
gapminder,
|
||||
x="year",
|
||||
y="life_expect",
|
||||
color="country",
|
||||
x_lim=[1950, 2010],
|
||||
y_lim=None,
|
||||
stroke_dash="cluster",
|
||||
tooltip=['country', 'life_expect'],
|
||||
overlay_point=False,
|
||||
title="Life expectancy for countries",
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as line_plot:
|
||||
dataset = gr.Dropdown(
|
||||
choices=["stocks", "climate", "seattle_weather", "gapminder"],
|
||||
value="stocks",
|
||||
temp_by_time = gr.LinePlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
)
|
||||
plot = gr.LinePlot()
|
||||
dataset.change(line_plot_fn, inputs=dataset, outputs=plot)
|
||||
line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)
|
||||
temp_by_time_location = gr.LinePlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
color="location",
|
||||
)
|
||||
|
||||
time_graphs = [temp_by_time, temp_by_time_location]
|
||||
group_by.change(
|
||||
lambda group: [gr.LinePlot(x_bin=None if group == "None" else group)] * len(time_graphs),
|
||||
group_by,
|
||||
time_graphs
|
||||
)
|
||||
aggregate.change(
|
||||
lambda aggregate: [gr.LinePlot(y_aggregate=aggregate)] * len(time_graphs),
|
||||
aggregate,
|
||||
time_graphs
|
||||
)
|
||||
|
||||
|
||||
def rescale(select: gr.SelectData):
|
||||
return select.index
|
||||
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
|
||||
|
||||
for trigger in [apply_btn.click, rescale_evt.then]:
|
||||
trigger(
|
||||
lambda start, end: [gr.LinePlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
|
||||
)
|
||||
|
||||
price_by_cuisine = gr.LinePlot(
|
||||
food_rating_data,
|
||||
x="cuisine",
|
||||
y="price",
|
||||
)
|
||||
with gr.Row():
|
||||
price_by_rating = gr.LinePlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
)
|
||||
price_by_rating_color = gr.LinePlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
color="cuisine",
|
||||
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
line_plot.launch()
|
||||
line_plots.launch()
|
||||
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/data.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plots\n", "from line_plot_demo import line_plots\n", "from bar_plot_demo import bar_plots\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plots.render()\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plots.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plots.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -1,18 +1,18 @@
|
||||
import gradio as gr
|
||||
|
||||
from scatter_plot_demo import scatter_plot
|
||||
from line_plot_demo import line_plot
|
||||
from bar_plot_demo import bar_plot
|
||||
from scatter_plot_demo import scatter_plots
|
||||
from line_plot_demo import line_plots
|
||||
from bar_plot_demo import bar_plots
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Scatter Plot"):
|
||||
scatter_plot.render()
|
||||
with gr.TabItem("Line Plot"):
|
||||
line_plot.render()
|
||||
line_plots.render()
|
||||
with gr.TabItem("Scatter Plot"):
|
||||
scatter_plots.render()
|
||||
with gr.TabItem("Bar Plot"):
|
||||
bar_plot.render()
|
||||
bar_plots.render()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -1,46 +1,71 @@
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from data import temp_sensor_data, food_rating_data
|
||||
|
||||
from vega_datasets import data
|
||||
with gr.Blocks() as scatter_plots:
|
||||
with gr.Row():
|
||||
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
|
||||
end = gr.DateTime("2021-01-05 00:00:00", label="End")
|
||||
apply_btn = gr.Button("Apply", scale=0)
|
||||
with gr.Row():
|
||||
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
|
||||
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
|
||||
|
||||
cars = data.cars()
|
||||
iris = data.iris()
|
||||
temp_by_time = gr.ScatterPlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
)
|
||||
temp_by_time_location = gr.ScatterPlot(
|
||||
temp_sensor_data,
|
||||
x="time",
|
||||
y="temperature",
|
||||
color="location",
|
||||
)
|
||||
|
||||
time_graphs = [temp_by_time, temp_by_time_location]
|
||||
group_by.change(
|
||||
lambda group: [gr.ScatterPlot(x_bin=None if group == "None" else group)] * len(time_graphs),
|
||||
group_by,
|
||||
time_graphs
|
||||
)
|
||||
aggregate.change(
|
||||
lambda aggregate: [gr.ScatterPlot(y_aggregate=aggregate)] * len(time_graphs),
|
||||
aggregate,
|
||||
time_graphs
|
||||
)
|
||||
|
||||
|
||||
def scatter_plot_fn(dataset):
|
||||
if dataset == "iris":
|
||||
return gr.ScatterPlot(
|
||||
value=iris,
|
||||
x="petalWidth",
|
||||
y="petalLength",
|
||||
color=None,
|
||||
title="Iris Dataset",
|
||||
x_title="Petal Width",
|
||||
y_title="Petal Length",
|
||||
tooltip=["petalWidth", "petalLength", "species"],
|
||||
caption="",
|
||||
height=600,
|
||||
width=600,
|
||||
# def rescale(select: gr.SelectData):
|
||||
# return select.index
|
||||
# rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
|
||||
|
||||
# for trigger in [apply_btn.click, rescale_evt.then]:
|
||||
# trigger(
|
||||
# lambda start, end: [gr.ScatterPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
|
||||
# )
|
||||
|
||||
price_by_cuisine = gr.ScatterPlot(
|
||||
food_rating_data,
|
||||
x="cuisine",
|
||||
y="price",
|
||||
)
|
||||
with gr.Row():
|
||||
price_by_rating = gr.ScatterPlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
color="wait",
|
||||
show_actions_button=True,
|
||||
)
|
||||
else:
|
||||
return gr.ScatterPlot(
|
||||
value=cars,
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
color="Origin",
|
||||
tooltip="Name",
|
||||
title="Car Data",
|
||||
y_title="Miles per Gallon",
|
||||
caption="MPG vs Horsepower of various cars",
|
||||
height=None,
|
||||
width=None,
|
||||
price_by_rating_color = gr.ScatterPlot(
|
||||
food_rating_data,
|
||||
x="rating",
|
||||
y="price",
|
||||
color="cuisine",
|
||||
# color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as scatter_plot:
|
||||
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
|
||||
plot = gr.ScatterPlot(show_label=False)
|
||||
dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
|
||||
scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)
|
||||
|
||||
if __name__ == "__main__":
|
||||
scatter_plot.launch()
|
||||
scatter_plots.launch()
|
||||
|
@ -942,6 +942,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
js: str | None = None,
|
||||
head: str | None = None,
|
||||
fill_height: bool = False,
|
||||
fill_width: bool = False,
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -955,6 +956,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
js: Custom js as a string or path to a js file. The custom js should be in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
|
||||
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
|
||||
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
|
||||
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width. Only applies if this is the outermost `Blocks` in your Gradio app.
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
"""
|
||||
self.limiter = None
|
||||
@ -988,6 +990,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
self.show_error = True
|
||||
self.head = head
|
||||
self.fill_height = fill_height
|
||||
self.fill_width = fill_width
|
||||
self.delete_cache = delete_cache
|
||||
if css is not None and os.path.exists(css):
|
||||
with open(css, encoding="utf-8") as css_file:
|
||||
@ -2036,6 +2039,7 @@ Received outputs:
|
||||
),
|
||||
},
|
||||
"fill_height": self.fill_height,
|
||||
"fill_width": self.fill_width,
|
||||
"theme_hash": self.theme_hash,
|
||||
}
|
||||
config.update(self.default_config.get_config()) # type: ignore
|
||||
|
@ -86,6 +86,7 @@ class ChatInterface(Blocks):
|
||||
fill_height: bool = True,
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
show_progress: Literal["full", "minimal", "hidden"] = "minimal",
|
||||
fill_width: bool = False,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -116,6 +117,7 @@ class ChatInterface(Blocks):
|
||||
fill_height: If True, the chat interface will expand to the height of window.
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
show_progress: whether to show progress animation while running.
|
||||
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
|
||||
"""
|
||||
super().__init__(
|
||||
analytics_enabled=analytics_enabled,
|
||||
@ -126,6 +128,7 @@ class ChatInterface(Blocks):
|
||||
js=js,
|
||||
head=head,
|
||||
fill_height=fill_height,
|
||||
fill_width=fill_width,
|
||||
delete_cache=delete_cache,
|
||||
)
|
||||
self.type: Literal["messages", "tuples"] = type
|
||||
|
@ -87,7 +87,7 @@ OVERRIDES = {
|
||||
"Plot": ComponentFiles(template="Plot", demo_code=static_only_demo_code),
|
||||
"BarPlot": ComponentFiles(
|
||||
template="BarPlot",
|
||||
python_file_name="bar_plot.py",
|
||||
python_file_name="native_plot.py",
|
||||
js_dir="plot",
|
||||
demo_code=static_only_demo_code,
|
||||
),
|
||||
@ -121,7 +121,7 @@ OVERRIDES = {
|
||||
),
|
||||
"LinePlot": ComponentFiles(
|
||||
template="LinePlot",
|
||||
python_file_name="line_plot.py",
|
||||
python_file_name="native_plot.py",
|
||||
js_dir="plot",
|
||||
demo_code=static_only_demo_code,
|
||||
),
|
||||
@ -139,7 +139,7 @@ OVERRIDES = {
|
||||
),
|
||||
"ScatterPlot": ComponentFiles(
|
||||
template="ScatterPlot",
|
||||
python_file_name="scatter_plot.py",
|
||||
python_file_name="native_plot.py",
|
||||
js_dir="plot",
|
||||
demo_code=static_only_demo_code,
|
||||
),
|
||||
|
@ -1,6 +1,5 @@
|
||||
from gradio.components.annotated_image import AnnotatedImage
|
||||
from gradio.components.audio import Audio
|
||||
from gradio.components.bar_plot import BarPlot
|
||||
from gradio.components.base import (
|
||||
Component,
|
||||
FormComponent,
|
||||
@ -33,17 +32,16 @@ from gradio.components.image import Image
|
||||
from gradio.components.image_editor import ImageEditor
|
||||
from gradio.components.json_component import JSON
|
||||
from gradio.components.label import Label
|
||||
from gradio.components.line_plot import LinePlot
|
||||
from gradio.components.login_button import LoginButton
|
||||
from gradio.components.logout_button import LogoutButton
|
||||
from gradio.components.markdown import Markdown
|
||||
from gradio.components.model3d import Model3D
|
||||
from gradio.components.multimodal_textbox import MultimodalTextbox
|
||||
from gradio.components.native_plot import BarPlot, LinePlot, ScatterPlot
|
||||
from gradio.components.number import Number
|
||||
from gradio.components.paramviewer import ParamViewer
|
||||
from gradio.components.plot import Plot
|
||||
from gradio.components.radio import Radio
|
||||
from gradio.components.scatter_plot import ScatterPlot
|
||||
from gradio.components.slider import Slider
|
||||
from gradio.components.state import State
|
||||
from gradio.components.textbox import Textbox
|
||||
|
263
gradio/components/native_plot.py
Normal file
263
gradio/components/native_plot.py
Normal file
@ -0,0 +1,263 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
import pandas as pd
|
||||
from gradio_client.documentation import document
|
||||
|
||||
from gradio.components.base import Component
|
||||
from gradio.data_classes import GradioModel
|
||||
from gradio.events import Events
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Timer
|
||||
|
||||
|
||||
class PlotData(GradioModel):
|
||||
columns: List[str]
|
||||
data: List[List[Any]]
|
||||
datatypes: Dict[str, Literal["quantitative", "nominal", "temporal"]]
|
||||
mark: str
|
||||
|
||||
|
||||
class NativePlot(Component):
|
||||
"""
|
||||
Creates a native Gradio plot component to display data from a pandas DataFrame. Supports interactivity and updates.
|
||||
|
||||
Demos: native_plots
|
||||
"""
|
||||
|
||||
EVENTS = [Events.select]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: pd.DataFrame | Callable | None = None,
|
||||
x: str | None = None,
|
||||
y: str | None = None,
|
||||
*,
|
||||
color: str | None = None,
|
||||
title: str | None = None,
|
||||
x_title: str | None = None,
|
||||
y_title: str | None = None,
|
||||
color_title: str | None = None,
|
||||
x_bin: str | float | None = None,
|
||||
y_aggregate: Literal["sum", "mean", "median", "min", "max", "count"]
|
||||
| None = None,
|
||||
color_map: dict[str, str] | None = None,
|
||||
x_lim: list[float] | None = None,
|
||||
y_lim: list[float] | None = None,
|
||||
caption: str | None = None,
|
||||
sort: Literal["x", "y", "-x", "-y"] | list[str] | None = None,
|
||||
height: int | None = None,
|
||||
label: str | None = None,
|
||||
show_label: bool | None = None,
|
||||
container: bool = True,
|
||||
scale: int | None = None,
|
||||
min_width: int = 160,
|
||||
every: Timer | float | None = None,
|
||||
inputs: Component | Sequence[Component] | AbstractSet[Component] | None = None,
|
||||
visible: bool = True,
|
||||
elem_id: str | None = None,
|
||||
elem_classes: list[str] | str | None = None,
|
||||
render: bool = True,
|
||||
key: int | str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
value: The pandas dataframe containing the data to display in the plot.
|
||||
x: Column corresponding to the x axis. Column can be numeric, datetime, or string/category.
|
||||
y: Column corresponding to the y axis. Column must be numeric.
|
||||
color: Column corresponding to series, visualized by color. Column must be string/category.
|
||||
title: The title to display on top of the chart.
|
||||
x_title: The title given to the x axis. By default, uses the value of the x parameter.
|
||||
y_title: The title given to the y axis. By default, uses the value of the y parameter.
|
||||
color_title: The title given to the color legend. By default, uses the value of color parameter.
|
||||
x_bin: Grouping used to cluster x values. If x column is numeric, should be number to bin the x values. If x column is datetime, should be string such as "1h", "15m", "10s", using "s", "m", "h", "d" suffixes.
|
||||
y_aggregate: Aggregation function used to aggregate y values, used if x_bin is provided or x is a string/category. Must be one of "sum", "mean", "median", "min", "max".
|
||||
color_map: Mapping of series to color names or codes. For example, {"success": "green", "fail": "#FF8888"}.
|
||||
height: The height of the plot in pixels.
|
||||
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max]. If x column is datetime type, x_lim should be timestamps.
|
||||
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
|
||||
caption: The (optional) caption to display below the plot.
|
||||
sort: The sorting order of the x values, if x column is type string/category. Can be "x", "y", "-x", "-y", or list of strings that represent the order of the categories.
|
||||
height: The height of the plot in pixels.
|
||||
label: The (optional) label to display on the top left corner of the plot.
|
||||
show_label: Whether the label should be displayed.
|
||||
container: If True, will place the component in a container - providing some extra padding around the border.
|
||||
scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.
|
||||
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
|
||||
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
|
||||
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
|
||||
visible: Whether the plot should be visible.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
|
||||
key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.
|
||||
"""
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.color = color
|
||||
self.title = title
|
||||
self.x_title = x_title
|
||||
self.y_title = y_title
|
||||
self.color_title = color_title
|
||||
self.x_bin = x_bin
|
||||
self.y_aggregate = y_aggregate
|
||||
self.color_map = color_map
|
||||
self.x_lim = x_lim
|
||||
self.y_lim = y_lim
|
||||
self.caption = caption
|
||||
self.sort = sort
|
||||
self.height = height
|
||||
|
||||
if label is None and show_label is None:
|
||||
show_label = False
|
||||
super().__init__(
|
||||
value=value,
|
||||
label=label,
|
||||
show_label=show_label,
|
||||
container=container,
|
||||
scale=scale,
|
||||
min_width=min_width,
|
||||
visible=visible,
|
||||
elem_id=elem_id,
|
||||
elem_classes=elem_classes,
|
||||
render=render,
|
||||
key=key,
|
||||
every=every,
|
||||
inputs=inputs,
|
||||
)
|
||||
for key, val in kwargs.items():
|
||||
if key == "color_legend_title":
|
||||
self.color_title = val
|
||||
if key in [
|
||||
"stroke_dash",
|
||||
"overlay_point",
|
||||
"tooltip",
|
||||
"x_label_angle",
|
||||
"y_label_angle",
|
||||
"interactive",
|
||||
"show_actions_button",
|
||||
"color_legend_title",
|
||||
"width",
|
||||
]:
|
||||
warnings.warn(
|
||||
f"Argument '{key}' has been deprecated.", DeprecationWarning
|
||||
)
|
||||
|
||||
def get_block_name(self) -> str:
|
||||
return "nativeplot"
|
||||
|
||||
def get_mark(self) -> str:
|
||||
return "native"
|
||||
|
||||
def preprocess(self, payload: PlotData | None) -> PlotData | None:
|
||||
"""
|
||||
Parameters:
|
||||
payload: The data to display in a line plot.
|
||||
Returns:
|
||||
The data to display in a line plot.
|
||||
"""
|
||||
return payload
|
||||
|
||||
def postprocess(self, value: pd.DataFrame | dict | None) -> PlotData | None:
|
||||
"""
|
||||
Parameters:
|
||||
value: Expects a pandas DataFrame containing the data to display in the line plot. The DataFrame should contain at least two columns, one for the x-axis (corresponding to this component's `x` argument) and one for the y-axis (corresponding to `y`).
|
||||
Returns:
|
||||
The data to display in a line plot, in the form of an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "line").
|
||||
"""
|
||||
# if None or update
|
||||
if value is None or isinstance(value, dict):
|
||||
return value
|
||||
|
||||
def get_simplified_type(dtype):
|
||||
if pd.api.types.is_numeric_dtype(dtype):
|
||||
return "quantitative"
|
||||
elif pd.api.types.is_string_dtype(
|
||||
dtype
|
||||
) or pd.api.types.is_categorical_dtype(dtype):
|
||||
return "nominal"
|
||||
elif pd.api.types.is_datetime64_any_dtype(dtype):
|
||||
return "temporal"
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {dtype}")
|
||||
|
||||
split_json = json.loads(value.to_json(orient="split", date_unit="ms"))
|
||||
datatypes = {
|
||||
col: get_simplified_type(value[col].dtype) for col in value.columns
|
||||
}
|
||||
return PlotData(
|
||||
columns=split_json["columns"],
|
||||
data=split_json["data"],
|
||||
datatypes=datatypes,
|
||||
mark=self.get_mark(),
|
||||
)
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return None
|
||||
|
||||
def example_value(self) -> Any:
|
||||
import pandas as pd
|
||||
|
||||
return pd.DataFrame({self.x: [1, 2, 3], self.y: [4, 5, 6]})
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": {}, "description": "any valid json"}
|
||||
|
||||
|
||||
@document()
|
||||
class BarPlot(NativePlot):
|
||||
"""
|
||||
Creates a bar plot component to display data from a pandas DataFrame.
|
||||
|
||||
Demos: native_plots
|
||||
"""
|
||||
|
||||
def get_block_name(self) -> str:
|
||||
return "nativeplot"
|
||||
|
||||
def get_mark(self) -> str:
|
||||
return "bar"
|
||||
|
||||
|
||||
@document()
|
||||
class LinePlot(NativePlot):
|
||||
"""
|
||||
Creates a line plot component to display data from a pandas DataFrame.
|
||||
|
||||
Demos: native_plots
|
||||
"""
|
||||
|
||||
def get_block_name(self) -> str:
|
||||
return "nativeplot"
|
||||
|
||||
def get_mark(self) -> str:
|
||||
return "line"
|
||||
|
||||
|
||||
@document()
|
||||
class ScatterPlot(NativePlot):
|
||||
"""
|
||||
Creates a scatter plot component to display data from a pandas DataFrame.
|
||||
|
||||
Demos: native_plots
|
||||
"""
|
||||
|
||||
def get_block_name(self) -> str:
|
||||
return "nativeplot"
|
||||
|
||||
def get_mark(self) -> str:
|
||||
return "point"
|
@ -315,6 +315,7 @@ class BlocksConfigDict(TypedDict):
|
||||
protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
|
||||
body_css: BodyCSS
|
||||
fill_height: bool
|
||||
fill_width: bool
|
||||
theme_hash: str
|
||||
layout: NotRequired[Layout]
|
||||
dependencies: NotRequired[list[dict[str, Any]]]
|
||||
|
@ -133,6 +133,7 @@ class Interface(Blocks):
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
show_progress: Literal["full", "minimal", "hidden"] = "full",
|
||||
example_labels: list[str] | None = None,
|
||||
fill_width: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -170,6 +171,7 @@ class Interface(Blocks):
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
show_progress: whether to show progress animation while running. Has no effect if the interface is `live`.
|
||||
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
|
||||
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
|
||||
"""
|
||||
super().__init__(
|
||||
analytics_enabled=analytics_enabled,
|
||||
@ -180,6 +182,7 @@ class Interface(Blocks):
|
||||
js=js,
|
||||
head=head,
|
||||
delete_cache=delete_cache,
|
||||
fill_width=fill_width,
|
||||
**kwargs,
|
||||
)
|
||||
self.api_name: str | Literal[False] | None = api_name
|
||||
|
@ -8,13 +8,16 @@ import gradio as gr
|
||||
data = {"data": {}}
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# Monitoring Dashboard")
|
||||
timer = gr.Timer(5)
|
||||
with gr.Row():
|
||||
selected_function = gr.Dropdown(
|
||||
start = gr.DateTime("now - 24h", label="Start Time")
|
||||
end = gr.DateTime("now", label="End Time")
|
||||
selected_fn = gr.Dropdown(
|
||||
["All"],
|
||||
value="All",
|
||||
label="Endpoint",
|
||||
info="Select the function to see analytics for, or 'All' for aggregate.",
|
||||
scale=2,
|
||||
)
|
||||
demo.load(
|
||||
lambda: gr.Dropdown(
|
||||
@ -22,70 +25,78 @@ with gr.Blocks() as demo:
|
||||
+ list({row["function"] for row in data["data"].values()}) # type: ignore
|
||||
),
|
||||
None,
|
||||
selected_function,
|
||||
)
|
||||
timespan = gr.Dropdown(
|
||||
["All Time", "24 hours", "1 hours", "10 minutes"],
|
||||
value="All Time",
|
||||
label="Timespan",
|
||||
info="Duration to see data for.",
|
||||
selected_fn,
|
||||
)
|
||||
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
unique_users = gr.Label(label="Unique Users")
|
||||
unique_requests = gr.Label(label="Unique Requests")
|
||||
total_requests = gr.Label(label="Total Requests")
|
||||
process_time = gr.Label(label="Avg Process Time")
|
||||
|
||||
plot = gr.BarPlot(
|
||||
x="time",
|
||||
y="count",
|
||||
y="function",
|
||||
color="status",
|
||||
title="Requests over Time",
|
||||
y_title="Requests",
|
||||
width=600,
|
||||
x_bin="1m",
|
||||
y_aggregate="count",
|
||||
color_map={
|
||||
"success": "#22c55e",
|
||||
"failure": "#ef4444",
|
||||
"pending": "#eab308",
|
||||
"queued": "#3b82f6",
|
||||
},
|
||||
)
|
||||
|
||||
@gr.on(
|
||||
[demo.load, selected_function.change, timespan.change],
|
||||
inputs=[selected_function, timespan],
|
||||
outputs=[unique_users, unique_requests, process_time, plot],
|
||||
[demo.load, timer.tick, start.change, end.change, selected_fn.change],
|
||||
inputs=[start, end, selected_fn],
|
||||
outputs=[plot, unique_users, total_requests, process_time],
|
||||
)
|
||||
def load_dfs(function, timespan):
|
||||
df = pd.DataFrame(data["data"].values())
|
||||
if df.empty:
|
||||
return 0, 0, 0, gr.skip()
|
||||
def gen_plot(start, end, selected_fn):
|
||||
df = pd.DataFrame(list(data["data"].values()))
|
||||
if selected_fn != "All":
|
||||
df = df[df["function"] == selected_fn]
|
||||
df = df[(df["time"] >= start) & (df["time"] <= end)]
|
||||
df["time"] = pd.to_datetime(df["time"], unit="s")
|
||||
df_filtered = df if function == "All" else df[df["function"] == function]
|
||||
if timespan != "All Time":
|
||||
df_filtered = df_filtered[
|
||||
df_filtered["time"] > pd.Timestamp.now() - pd.Timedelta(timespan) # type: ignore
|
||||
]
|
||||
|
||||
df_filtered["time"] = df_filtered["time"].dt.floor("min") # type: ignore
|
||||
plot = df_filtered.groupby(["time", "status"]).size().reset_index(name="count") # type: ignore
|
||||
mean_process_time_for_success = df_filtered[df_filtered["status"] == "success"][
|
||||
"process_time"
|
||||
].mean()
|
||||
unique_users = len(df["session_hash"].unique())
|
||||
total_requests = len(df)
|
||||
process_time = round(df["process_time"].mean(), 2)
|
||||
|
||||
duration = end - start
|
||||
x_bin = (
|
||||
"1h"
|
||||
if duration >= 60 * 60 * 24
|
||||
else "15m"
|
||||
if duration >= 60 * 60 * 3
|
||||
else "1m"
|
||||
)
|
||||
df = df.drop(columns=["session_hash"])
|
||||
return (
|
||||
df_filtered["session_hash"].nunique(), # type: ignore
|
||||
df_filtered.shape[0],
|
||||
round(mean_process_time_for_success, 2),
|
||||
plot,
|
||||
gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]),
|
||||
unique_users,
|
||||
total_requests,
|
||||
process_time,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data["data"] = {
|
||||
random.randint(0, 1000000): {
|
||||
"time": time.time() - random.randint(0, 60 * 60 * 24 * 3),
|
||||
data["data"] = {}
|
||||
for _ in range(random.randint(300, 500)):
|
||||
timedelta = random.randint(0, 60 * 60 * 24 * 3)
|
||||
data["data"][random.randint(1, 100000)] = {
|
||||
"time": time.time() - timedelta,
|
||||
"status": random.choice(
|
||||
["success", "success", "failure", "pending", "queued"]
|
||||
["success", "success", "failure"]
|
||||
if timedelta > 30 * 60
|
||||
else ["queued", "pending"]
|
||||
),
|
||||
"function": random.choice(["predict", "chat", "chat"]),
|
||||
"process_time": random.randint(0, 10),
|
||||
"session_hash": str(random.randint(0, 4)),
|
||||
}
|
||||
for r in range(random.randint(100, 200))
|
||||
}
|
||||
|
||||
demo.launch()
|
||||
|
@ -563,7 +563,8 @@ def get_all_components() -> list[type[Component] | type[BlockContext]]:
|
||||
return [
|
||||
c
|
||||
for c in subclasses
|
||||
if c.__name__ not in ["ChatInterface", "Interface", "Blocks", "TabbedInterface"]
|
||||
if c.__name__
|
||||
not in ["ChatInterface", "Interface", "Blocks", "TabbedInterface", "NativePlot"]
|
||||
]
|
||||
|
||||
|
||||
|
@ -57,6 +57,7 @@
|
||||
"@gradio/markdown": "workspace:^",
|
||||
"@gradio/model3d": "workspace:^",
|
||||
"@gradio/multimodaltextbox": "workspace:^",
|
||||
"@gradio/nativeplot": "workspace:^",
|
||||
"@gradio/number": "workspace:^",
|
||||
"@gradio/paramviewer": "workspace:^",
|
||||
"@gradio/plot": "workspace:^",
|
||||
|
@ -4,6 +4,7 @@
|
||||
export let wrapper: HTMLDivElement;
|
||||
export let version: string;
|
||||
export let initial_height: string;
|
||||
export let fill_width: boolean;
|
||||
export let is_embed: boolean;
|
||||
|
||||
export let space: string | null;
|
||||
@ -15,6 +16,7 @@
|
||||
<div
|
||||
bind:this={wrapper}
|
||||
class:app={!display && !is_embed}
|
||||
class:fill_width
|
||||
class:embed-container={display}
|
||||
class:with-info={info}
|
||||
class="gradio-container gradio-container-{version}"
|
||||
@ -87,27 +89,27 @@
|
||||
}
|
||||
|
||||
@media (--screen-sm) {
|
||||
.app {
|
||||
.app:not(.fill_width) {
|
||||
max-width: 640px;
|
||||
}
|
||||
}
|
||||
@media (--screen-md) {
|
||||
.app {
|
||||
.app:not(.fill_width) {
|
||||
max-width: 768px;
|
||||
}
|
||||
}
|
||||
@media (--screen-lg) {
|
||||
.app {
|
||||
.app:not(.fill_width) {
|
||||
max-width: 1024px;
|
||||
}
|
||||
}
|
||||
@media (--screen-xl) {
|
||||
.app {
|
||||
.app:not(.fill_width) {
|
||||
max-width: 1280px;
|
||||
}
|
||||
}
|
||||
@media (--screen-xxl) {
|
||||
.app {
|
||||
.app:not(.fill_width) {
|
||||
max-width: 1536px;
|
||||
}
|
||||
}
|
||||
|
@ -29,6 +29,7 @@
|
||||
path: string;
|
||||
app_id?: string;
|
||||
fill_height?: boolean;
|
||||
fill_width?: boolean;
|
||||
theme_hash?: number;
|
||||
username: string | null;
|
||||
}
|
||||
@ -412,6 +413,7 @@
|
||||
{initial_height}
|
||||
{space}
|
||||
loaded={loader_status === "complete"}
|
||||
fill_width={config?.fill_width || false}
|
||||
bind:wrapper
|
||||
>
|
||||
{#if (loader_status === "pending" || loader_status === "error") && !(config && config?.auth_required)}
|
||||
|
@ -25,6 +25,8 @@
|
||||
export let include_time = true;
|
||||
$: if (value !== old_value) {
|
||||
old_value = value;
|
||||
entered_value = value;
|
||||
datevalue = value;
|
||||
gradio.dispatch("change");
|
||||
}
|
||||
|
||||
@ -49,7 +51,7 @@
|
||||
|
||||
let entered_value = value;
|
||||
let datetime: HTMLInputElement;
|
||||
let datevalue = "";
|
||||
let datevalue = value;
|
||||
|
||||
const date_is_valid_format = (date: string): boolean => {
|
||||
if (date === "") return false;
|
||||
@ -156,7 +158,6 @@
|
||||
flex-shrink: 1;
|
||||
display: flex;
|
||||
position: relative;
|
||||
box-shadow: var(--input-shadow);
|
||||
background: var(--input-background-fill);
|
||||
}
|
||||
.timebox :global(svg) {
|
||||
@ -175,6 +176,7 @@
|
||||
border-right: none;
|
||||
border-top-left-radius: var(--input-radius);
|
||||
border-bottom-left-radius: var(--input-radius);
|
||||
box-shadow: var(--input-shadow);
|
||||
}
|
||||
.time.invalid {
|
||||
color: var(--body-text-color-subdued);
|
||||
@ -193,6 +195,13 @@
|
||||
border-top-right-radius: var(--input-radius);
|
||||
border-bottom-right-radius: var(--input-radius);
|
||||
padding: var(--size-2);
|
||||
border: var(--input-border-width) solid var(--input-border-color);
|
||||
}
|
||||
.calendar:hover {
|
||||
background: var(--button-secondary-background-fill-hover);
|
||||
}
|
||||
.calendar:active {
|
||||
box-shadow: var(--button-shadow-active);
|
||||
}
|
||||
.datetime {
|
||||
width: 0px;
|
||||
|
1
js/nativeplot/CHANGELOG.md
Normal file
1
js/nativeplot/CHANGELOG.md
Normal file
@ -0,0 +1 @@
|
||||
# @gradio/nativeplot
|
11
js/nativeplot/Example.svelte
Normal file
11
js/nativeplot/Example.svelte
Normal file
@ -0,0 +1,11 @@
|
||||
<script lang="ts">
|
||||
export let title: string | null;
|
||||
export let x: string;
|
||||
export let y: string;
|
||||
</script>
|
||||
|
||||
{#if title}
|
||||
{title}
|
||||
{:else}
|
||||
{x} x {y}
|
||||
{/if}
|
505
js/nativeplot/Index.svelte
Normal file
505
js/nativeplot/Index.svelte
Normal file
@ -0,0 +1,505 @@
|
||||
<script lang="ts">
|
||||
import type { Gradio, SelectData } from "@gradio/utils";
|
||||
import { BlockTitle } from "@gradio/atoms";
|
||||
import { Block } from "@gradio/atoms";
|
||||
import { StatusTracker } from "@gradio/statustracker";
|
||||
import type { LoadingStatus } from "@gradio/statustracker";
|
||||
import { onMount } from "svelte";
|
||||
|
||||
import type { TopLevelSpec as Spec } from "vega-lite";
|
||||
import vegaEmbed from "vega-embed";
|
||||
import type { View } from "vega";
|
||||
import { LineChart as LabelIcon } from "@gradio/icons";
|
||||
import { Empty } from "@gradio/atoms";
|
||||
|
||||
interface PlotData {
|
||||
columns: string[];
|
||||
data: [string | number][];
|
||||
datatypes: Record<string, "quantitative" | "temporal" | "nominal">;
|
||||
mark: "line" | "point" | "bar";
|
||||
}
|
||||
export let value: PlotData | null;
|
||||
export let x: string;
|
||||
export let y: string;
|
||||
export let color: string | null = null;
|
||||
$: unique_colors =
|
||||
color && value && value.datatypes[color] === "nominal"
|
||||
? Array.from(new Set(_data.map((d) => d[color])))
|
||||
: [];
|
||||
|
||||
export let title: string | null = null;
|
||||
export let x_title: string | null = null;
|
||||
export let y_title: string | null = null;
|
||||
export let color_title: string | null = null;
|
||||
export let x_bin: string | number | null = null;
|
||||
export let y_aggregate:
|
||||
| "sum"
|
||||
| "mean"
|
||||
| "median"
|
||||
| "min"
|
||||
| "max"
|
||||
| undefined = undefined;
|
||||
export let color_map: Record<string, string> | null = null;
|
||||
export let x_lim: [number, number] | null = null;
|
||||
export let y_lim: [number, number] | null = null;
|
||||
export let caption: string | null = null;
|
||||
export let sort: "x" | "y" | "-x" | "-y" | string[] | null = null;
|
||||
function reformat_sort(
|
||||
_sort: typeof sort
|
||||
):
|
||||
| string
|
||||
| "ascending"
|
||||
| "descending"
|
||||
| { field: string; order: "ascending" | "descending" }
|
||||
| string[]
|
||||
| undefined {
|
||||
if (_sort === "x") {
|
||||
return "ascending";
|
||||
} else if (_sort === "-x") {
|
||||
return "descending";
|
||||
} else if (_sort === "y") {
|
||||
return { field: y, order: "ascending" };
|
||||
} else if (_sort === "-y") {
|
||||
return { field: y, order: "descending" };
|
||||
} else if (_sort === null) {
|
||||
return undefined;
|
||||
} else if (Array.isArray(_sort)) {
|
||||
return _sort;
|
||||
}
|
||||
}
|
||||
$: _sort = reformat_sort(sort);
|
||||
export let _selectable = false;
|
||||
export let target: HTMLDivElement;
|
||||
let _data: {
|
||||
[x: string]: string | number;
|
||||
}[];
|
||||
export let gradio: Gradio<{
|
||||
select: SelectData;
|
||||
clear_status: LoadingStatus;
|
||||
}>;
|
||||
|
||||
$: x_temporal = value && value.datatypes[x] === "temporal";
|
||||
$: _x_lim = x_lim && x_temporal ? [x_lim[0] * 1000, x_lim[1] * 1000] : x_lim;
|
||||
let _x_bin: number | undefined;
|
||||
let mouse_down_on_chart = false;
|
||||
const SUFFIX_DURATION: Record<string, number> = {
|
||||
s: 1,
|
||||
m: 60,
|
||||
h: 60 * 60,
|
||||
d: 24 * 60 * 60
|
||||
};
|
||||
$: _x_bin = x_bin
|
||||
? typeof x_bin === "string"
|
||||
? 1000 *
|
||||
parseInt(x_bin.substring(0, x_bin.length - 1)) *
|
||||
SUFFIX_DURATION[x_bin[x_bin.length - 1]]
|
||||
: x_bin
|
||||
: undefined;
|
||||
let _y_aggregate: typeof y_aggregate;
|
||||
let aggregating: boolean;
|
||||
$: {
|
||||
if (value) {
|
||||
if (value.mark === "point") {
|
||||
aggregating = _x_bin !== undefined;
|
||||
_y_aggregate = y_aggregate || aggregating ? "sum" : undefined;
|
||||
} else {
|
||||
aggregating = _x_bin !== undefined || value.datatypes[x] === "nominal";
|
||||
_y_aggregate = y_aggregate ? y_aggregate : "sum";
|
||||
}
|
||||
}
|
||||
}
|
||||
function reformat_data(data: PlotData): {
|
||||
[x: string]: string | number;
|
||||
}[] {
|
||||
let x_index = data.columns.indexOf(x);
|
||||
let y_index = data.columns.indexOf(y);
|
||||
let color_index = color ? data.columns.indexOf(color) : null;
|
||||
return data.data.map((row) => {
|
||||
const obj = {
|
||||
[x]: row[x_index],
|
||||
[y]: row[y_index]
|
||||
};
|
||||
if (color && color_index !== null) {
|
||||
obj[color] = row[color_index];
|
||||
}
|
||||
return obj;
|
||||
});
|
||||
}
|
||||
$: _data = value ? reformat_data(value) : [];
|
||||
|
||||
let chart_element: HTMLDivElement;
|
||||
let computed_style = window.getComputedStyle(target);
|
||||
let view: View;
|
||||
let mounted = false;
|
||||
let old_width: number;
|
||||
|
||||
function load_chart(): void {
|
||||
if (view) {
|
||||
view.finalize();
|
||||
}
|
||||
if (!value) return;
|
||||
old_width = chart_element.offsetWidth;
|
||||
const spec = create_vega_lite_spec();
|
||||
if (!spec) return;
|
||||
let resizeObserver = new ResizeObserver(() => {
|
||||
if (
|
||||
old_width === 0 &&
|
||||
chart_element.offsetWidth !== 0 &&
|
||||
value.datatypes[x] === "nominal"
|
||||
) {
|
||||
// a bug where when a nominal chart is first loaded, the width is 0, it doesn't resize
|
||||
load_chart();
|
||||
} else {
|
||||
view.signal("width", chart_element.offsetWidth).run();
|
||||
}
|
||||
});
|
||||
vegaEmbed(chart_element, spec, { actions: false }).then(function (result) {
|
||||
view = result.view;
|
||||
resizeObserver.observe(chart_element);
|
||||
var debounceTimeout: NodeJS.Timeout;
|
||||
if (_selectable) {
|
||||
view.addSignalListener("brush", function (_, value) {
|
||||
if (Object.keys(value).length === 0) return;
|
||||
clearTimeout(debounceTimeout);
|
||||
let range: [number, number] = value[Object.keys(value)[0]];
|
||||
if (x_temporal) {
|
||||
range = [range[0] / 1000, range[1] / 1000];
|
||||
}
|
||||
let callback = (): void => {
|
||||
gradio.dispatch("select", {
|
||||
value: range,
|
||||
index: range,
|
||||
selected: true
|
||||
});
|
||||
};
|
||||
if (mouse_down_on_chart) {
|
||||
release_callback = callback;
|
||||
} else {
|
||||
debounceTimeout = setTimeout(function () {
|
||||
gradio.dispatch("select", {
|
||||
value: range,
|
||||
index: range,
|
||||
selected: true
|
||||
});
|
||||
}, 250);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let release_callback: (() => void) | null = null;
|
||||
onMount(() => {
|
||||
mounted = true;
|
||||
chart_element.addEventListener("mousedown", () => {
|
||||
mouse_down_on_chart = true;
|
||||
});
|
||||
chart_element.addEventListener("mouseup", () => {
|
||||
mouse_down_on_chart = false;
|
||||
if (release_callback) {
|
||||
release_callback();
|
||||
release_callback = null;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
$: title,
|
||||
x_title,
|
||||
y_title,
|
||||
color_title,
|
||||
x,
|
||||
y,
|
||||
color,
|
||||
x_bin,
|
||||
_y_aggregate,
|
||||
color_map,
|
||||
x_lim,
|
||||
y_lim,
|
||||
caption,
|
||||
sort,
|
||||
mounted && load_chart();
|
||||
|
||||
function create_vega_lite_spec(): Spec | null {
|
||||
if (!value) return null;
|
||||
let accent_color = computed_style.getPropertyValue("--color-accent");
|
||||
let body_text_color = computed_style.getPropertyValue("--body-text-color");
|
||||
let borderColorPrimary = computed_style.getPropertyValue(
|
||||
"--border-color-primary"
|
||||
);
|
||||
let font_family = computed_style.fontFamily;
|
||||
let title_weight = computed_style.getPropertyValue(
|
||||
"--block-title-text-weight"
|
||||
) as
|
||||
| "bold"
|
||||
| "normal"
|
||||
| 100
|
||||
| 200
|
||||
| 300
|
||||
| 400
|
||||
| 500
|
||||
| 600
|
||||
| 700
|
||||
| 800
|
||||
| 900;
|
||||
const font_to_px_val = (font: string): number => {
|
||||
return font.endsWith("px") ? parseFloat(font.slice(0, -2)) : 12;
|
||||
};
|
||||
let text_size_md = font_to_px_val(
|
||||
computed_style.getPropertyValue("--text-md")
|
||||
);
|
||||
let text_size_sm = font_to_px_val(
|
||||
computed_style.getPropertyValue("--text-sm")
|
||||
);
|
||||
|
||||
return {
|
||||
$schema: "https://vega.github.io/schema/vega-lite/v5.17.0.json",
|
||||
background: "transparent",
|
||||
config: {
|
||||
autosize: { type: "fit", contains: "padding" },
|
||||
axis: {
|
||||
labelFont: font_family,
|
||||
labelColor: body_text_color,
|
||||
titleFont: font_family,
|
||||
titleColor: body_text_color,
|
||||
titlePadding: 8,
|
||||
tickColor: borderColorPrimary,
|
||||
labelFontSize: text_size_sm,
|
||||
gridColor: borderColorPrimary,
|
||||
titleFontWeight: "normal",
|
||||
titleFontSize: text_size_sm,
|
||||
labelFontWeight: "normal",
|
||||
domain: false,
|
||||
labelAngle: 0
|
||||
},
|
||||
legend: {
|
||||
labelColor: body_text_color,
|
||||
labelFont: font_family,
|
||||
titleColor: body_text_color,
|
||||
titleFont: font_family,
|
||||
titleFontWeight: "normal",
|
||||
titleFontSize: text_size_sm,
|
||||
labelFontWeight: "normal",
|
||||
offset: 2
|
||||
},
|
||||
title: {
|
||||
color: body_text_color,
|
||||
font: font_family,
|
||||
fontSize: text_size_md,
|
||||
fontWeight: title_weight,
|
||||
anchor: "middle"
|
||||
},
|
||||
view: { stroke: borderColorPrimary },
|
||||
mark: {
|
||||
stroke: value.mark !== "bar" ? accent_color : undefined,
|
||||
fill: value.mark === "bar" ? accent_color : undefined,
|
||||
cursor: "crosshair"
|
||||
}
|
||||
},
|
||||
data: { name: "data" },
|
||||
datasets: {
|
||||
data: _data
|
||||
},
|
||||
layer: ["plot", ...(value.mark === "line" ? ["hover"] : [])].map(
|
||||
(mode) => {
|
||||
return {
|
||||
encoding: {
|
||||
size:
|
||||
value.mark === "line"
|
||||
? mode == "plot"
|
||||
? {
|
||||
condition: {
|
||||
empty: false,
|
||||
param: "hoverPlot",
|
||||
value: 3
|
||||
},
|
||||
value: 2
|
||||
}
|
||||
: {
|
||||
condition: { empty: false, param: "hover", value: 100 },
|
||||
value: 0
|
||||
}
|
||||
: undefined,
|
||||
opacity:
|
||||
mode === "plot"
|
||||
? undefined
|
||||
: {
|
||||
condition: { empty: false, param: "hover", value: 1 },
|
||||
value: 0
|
||||
},
|
||||
x: {
|
||||
axis: {},
|
||||
field: x,
|
||||
title: x_title || x,
|
||||
type: value.datatypes[x],
|
||||
scale: _x_lim ? { domain: _x_lim } : undefined,
|
||||
bin: _x_bin ? { step: _x_bin } : undefined,
|
||||
sort: _sort
|
||||
},
|
||||
y: {
|
||||
axis: {},
|
||||
field: y,
|
||||
title: y_title || y,
|
||||
type: value.datatypes[y],
|
||||
scale: y_lim ? { domain: y_lim } : undefined,
|
||||
aggregate: aggregating ? _y_aggregate : undefined
|
||||
},
|
||||
color: color
|
||||
? {
|
||||
field: color,
|
||||
legend: { orient: "bottom", title: color_title },
|
||||
scale:
|
||||
value.datatypes[color] === "nominal"
|
||||
? {
|
||||
domain: unique_colors,
|
||||
range: color_map
|
||||
? unique_colors.map((c) => color_map[c])
|
||||
: undefined
|
||||
}
|
||||
: {
|
||||
range: [
|
||||
100, 200, 300, 400, 500, 600, 700, 800, 900
|
||||
].map((n) =>
|
||||
computed_style.getPropertyValue("--primary-" + n)
|
||||
),
|
||||
interpolate: "hsl"
|
||||
},
|
||||
type: value.datatypes[color]
|
||||
}
|
||||
: undefined,
|
||||
tooltip: [
|
||||
{
|
||||
field: y,
|
||||
type: value.datatypes[y],
|
||||
aggregate: aggregating ? _y_aggregate : undefined,
|
||||
title: y_title || y
|
||||
},
|
||||
{
|
||||
field: x,
|
||||
type: value.datatypes[x],
|
||||
title: x_title || x,
|
||||
format: x_temporal ? "%Y-%m-%d %H:%M:%S" : undefined,
|
||||
bin: _x_bin ? { step: _x_bin } : undefined
|
||||
},
|
||||
...(color
|
||||
? [
|
||||
{
|
||||
field: color,
|
||||
type: value.datatypes[color]
|
||||
}
|
||||
]
|
||||
: [])
|
||||
]
|
||||
},
|
||||
strokeDash: {},
|
||||
mark: { clip: true, type: mode === "hover" ? "point" : value.mark },
|
||||
name: mode
|
||||
};
|
||||
}
|
||||
),
|
||||
// @ts-ignore
|
||||
params: [
|
||||
...(value.mark === "line"
|
||||
? [
|
||||
{
|
||||
name: "hoverPlot",
|
||||
select: {
|
||||
clear: "mouseout",
|
||||
fields: color ? [color] : [],
|
||||
nearest: true,
|
||||
on: "mouseover",
|
||||
type: "point" as "point"
|
||||
},
|
||||
views: ["hover"]
|
||||
},
|
||||
{
|
||||
name: "hover",
|
||||
select: {
|
||||
clear: "mouseout",
|
||||
nearest: true,
|
||||
on: "mouseover",
|
||||
type: "point" as "point"
|
||||
},
|
||||
views: ["hover"]
|
||||
}
|
||||
]
|
||||
: []),
|
||||
...(_selectable
|
||||
? [
|
||||
{
|
||||
name: "brush",
|
||||
select: {
|
||||
encodings: ["x"],
|
||||
mark: { fill: "gray", fillOpacity: 0.3, stroke: "none" },
|
||||
type: "interval" as "interval"
|
||||
},
|
||||
views: ["plot"]
|
||||
}
|
||||
]
|
||||
: [])
|
||||
],
|
||||
width: chart_element.offsetWidth,
|
||||
title: title || undefined
|
||||
};
|
||||
}
|
||||
|
||||
export let label = "Textbox";
|
||||
export let elem_id = "";
|
||||
export let elem_classes: string[] = [];
|
||||
export let visible = true;
|
||||
export let show_label: boolean;
|
||||
export let scale: number | null = null;
|
||||
export let min_width: number | undefined = undefined;
|
||||
export let loading_status: LoadingStatus | undefined = undefined;
|
||||
export let height: number | undefined = undefined;
|
||||
</script>
|
||||
|
||||
<Block
|
||||
{visible}
|
||||
{elem_id}
|
||||
{elem_classes}
|
||||
{scale}
|
||||
{min_width}
|
||||
allow_overflow={false}
|
||||
padding={true}
|
||||
{height}
|
||||
>
|
||||
{#if loading_status}
|
||||
<StatusTracker
|
||||
autoscroll={gradio.autoscroll}
|
||||
i18n={gradio.i18n}
|
||||
{...loading_status}
|
||||
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
|
||||
/>
|
||||
{/if}
|
||||
<BlockTitle {show_label} info={undefined}>{label}</BlockTitle>
|
||||
<div bind:this={chart_element}></div>
|
||||
{#if value}
|
||||
{#if caption}
|
||||
<p class="caption">{caption}</p>
|
||||
{/if}
|
||||
{:else}
|
||||
<Empty unpadded_box={true}><LabelIcon /></Empty>
|
||||
{/if}
|
||||
</Block>
|
||||
|
||||
<style>
|
||||
div {
|
||||
width: 100%;
|
||||
}
|
||||
:global(#vg-tooltip-element) {
|
||||
font-family: var(--font) !important;
|
||||
font-size: var(--text-xs) !important;
|
||||
box-shadow: none !important;
|
||||
background-color: var(--block-background-fill) !important;
|
||||
border: 1px solid var(--border-color-primary) !important;
|
||||
color: var(--body-text-color) !important;
|
||||
}
|
||||
:global(#vg-tooltip-element .key) {
|
||||
color: var(--body-text-color-subdued) !important;
|
||||
}
|
||||
.caption {
|
||||
padding: 0 4px;
|
||||
margin: 0;
|
||||
text-align: center;
|
||||
}
|
||||
</style>
|
28
js/nativeplot/package.json
Normal file
28
js/nativeplot/package.json
Normal file
@ -0,0 +1,28 @@
|
||||
{
|
||||
"name": "@gradio/nativeplot",
|
||||
"version": "0.0.1",
|
||||
"description": "Gradio UI packages",
|
||||
"type": "module",
|
||||
"author": "",
|
||||
"license": "ISC",
|
||||
"private": false,
|
||||
"main_changeset": true,
|
||||
"exports": {
|
||||
".": "./Index.svelte",
|
||||
"./example": "./Example.svelte",
|
||||
"./package.json": "./package.json"
|
||||
},
|
||||
"dependencies": {
|
||||
"@gradio/atoms": "workspace:^",
|
||||
"@gradio/icons": "workspace:^",
|
||||
"@gradio/statustracker": "workspace:^",
|
||||
"@gradio/utils": "workspace:^",
|
||||
"@gradio/theme": "workspace:^",
|
||||
"vega": "^5.23.0",
|
||||
"vega-embed": "^6.25.0",
|
||||
"vega-lite": "^5.12.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@gradio/preview": "workspace:^"
|
||||
}
|
||||
}
|
@ -122,6 +122,7 @@
|
||||
"@gradio/markdown": "workspace:^",
|
||||
"@gradio/model3d": "workspace:^",
|
||||
"@gradio/multimodaltextbox": "workspace:^",
|
||||
"@gradio/nativeplot": "workspace:^",
|
||||
"@gradio/number": "workspace:^",
|
||||
"@gradio/paramviewer": "workspace:^",
|
||||
"@gradio/plot": "workspace:^",
|
||||
|
47
pnpm-lock.yaml
generated
47
pnpm-lock.yaml
generated
@ -270,6 +270,9 @@ importers:
|
||||
'@gradio/multimodaltextbox':
|
||||
specifier: workspace:^
|
||||
version: link:js/multimodaltextbox
|
||||
'@gradio/nativeplot':
|
||||
specifier: workspace:^
|
||||
version: link:js/nativeplot
|
||||
'@gradio/number':
|
||||
specifier: workspace:^
|
||||
version: link:js/number
|
||||
@ -625,6 +628,9 @@ importers:
|
||||
'@gradio/multimodaltextbox':
|
||||
specifier: workspace:^
|
||||
version: link:../multimodaltextbox
|
||||
'@gradio/nativeplot':
|
||||
specifier: workspace:^
|
||||
version: link:../nativeplot
|
||||
'@gradio/number':
|
||||
specifier: workspace:^
|
||||
version: link:../number
|
||||
@ -1549,6 +1555,37 @@ importers:
|
||||
specifier: workspace:^
|
||||
version: link:../preview
|
||||
|
||||
js/nativeplot:
|
||||
dependencies:
|
||||
'@gradio/atoms':
|
||||
specifier: workspace:^
|
||||
version: link:../atoms
|
||||
'@gradio/icons':
|
||||
specifier: workspace:^
|
||||
version: link:../icons
|
||||
'@gradio/statustracker':
|
||||
specifier: workspace:^
|
||||
version: link:../statustracker
|
||||
'@gradio/theme':
|
||||
specifier: workspace:^
|
||||
version: link:../theme
|
||||
'@gradio/utils':
|
||||
specifier: workspace:^
|
||||
version: link:../utils
|
||||
vega:
|
||||
specifier: ^5.23.0
|
||||
version: 5.23.0
|
||||
vega-embed:
|
||||
specifier: ^6.25.0
|
||||
version: 6.25.0(vega-lite@5.12.0(vega@5.23.0))(vega@5.23.0)
|
||||
vega-lite:
|
||||
specifier: ^5.12.0
|
||||
version: 5.12.0(vega@5.23.0)
|
||||
devDependencies:
|
||||
'@gradio/preview':
|
||||
specifier: workspace:^
|
||||
version: link:../preview
|
||||
|
||||
js/number:
|
||||
dependencies:
|
||||
'@gradio/atoms':
|
||||
@ -6025,10 +6062,6 @@ packages:
|
||||
engines: {node: '>=12'}
|
||||
hasBin: true
|
||||
|
||||
escalade@3.1.1:
|
||||
resolution: {integrity: sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==}
|
||||
engines: {node: '>=6'}
|
||||
|
||||
escalade@3.1.2:
|
||||
resolution: {integrity: sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==}
|
||||
engines: {node: '>=6'}
|
||||
@ -14422,8 +14455,6 @@ snapshots:
|
||||
'@esbuild/win32-ia32': 0.21.0
|
||||
'@esbuild/win32-x64': 0.21.0
|
||||
|
||||
escalade@3.1.1: {}
|
||||
|
||||
escalade@3.1.2: {}
|
||||
|
||||
escape-html@1.0.3: {}
|
||||
@ -18217,7 +18248,7 @@ snapshots:
|
||||
d3-timer: 3.0.1
|
||||
vega-dataflow: 5.7.5
|
||||
vega-format: 1.1.1
|
||||
vega-functions: 5.13.2
|
||||
vega-functions: 5.14.0
|
||||
vega-runtime: 6.1.4
|
||||
vega-scenegraph: 4.10.2
|
||||
vega-util: 1.17.2
|
||||
@ -18565,7 +18596,7 @@ snapshots:
|
||||
yargs@17.7.2:
|
||||
dependencies:
|
||||
cliui: 8.0.1
|
||||
escalade: 3.1.1
|
||||
escalade: 3.1.2
|
||||
get-caller-file: 2.0.5
|
||||
require-directory: 2.1.1
|
||||
string-width: 4.2.3
|
||||
|
@ -1,5 +1,4 @@
|
||||
aiofiles>=22.0,<24.0
|
||||
altair>=5.0,<6.0
|
||||
anyio>=3.0,<5.0
|
||||
fastapi
|
||||
ffmpy
|
||||
|
@ -1,12 +1,14 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import vega_datasets
|
||||
|
||||
cars = vega_datasets.data.cars()
|
||||
stocks = vega_datasets.data.stocks()
|
||||
barley = vega_datasets.data.barley()
|
||||
simple = pd.DataFrame(
|
||||
{
|
||||
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
|
||||
"b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
|
||||
"c": [0] * 9,
|
||||
"date": [datetime.now() - timedelta(days=i) for i in range(9)],
|
||||
}
|
||||
)
|
||||
|
@ -1,118 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .plot_data import barley, simple
|
||||
|
||||
|
||||
class TestBarPlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
assert gr.BarPlot().get_config() == {
|
||||
"caption": None,
|
||||
"elem_id": None,
|
||||
"elem_classes": [],
|
||||
"interactive": None,
|
||||
"label": None,
|
||||
"name": "barplot",
|
||||
"bokeh_version": "3.0.3",
|
||||
"show_actions_button": False,
|
||||
"proxy_url": None,
|
||||
"show_label": False,
|
||||
"container": True,
|
||||
"min_width": 160,
|
||||
"scale": None,
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"x": None,
|
||||
"y": None,
|
||||
"color": None,
|
||||
"vertical": True,
|
||||
"group": None,
|
||||
"title": None,
|
||||
"tooltip": None,
|
||||
"x_title": None,
|
||||
"y_title": None,
|
||||
"color_legend_title": None,
|
||||
"group_title": None,
|
||||
"color_legend_position": None,
|
||||
"height": None,
|
||||
"width": None,
|
||||
"y_lim": None,
|
||||
"x_label_angle": None,
|
||||
"y_label_angle": None,
|
||||
"sort": None,
|
||||
"_selectable": False,
|
||||
"key": None,
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
plot = gr.BarPlot(
|
||||
x="a",
|
||||
y="b",
|
||||
tooltip=["a", "b"],
|
||||
title="Made Up Bar Plot",
|
||||
x_title="Variable A",
|
||||
sort="x",
|
||||
)
|
||||
assert (output := plot.postprocess(simple))
|
||||
output = output.model_dump()
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
assert output["chart"] == "bar"
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["sort"] == "x"
|
||||
assert config["encoding"]["x"]["field"] == "a"
|
||||
assert config["encoding"]["x"]["title"] == "Variable A"
|
||||
assert config["encoding"]["y"]["field"] == "b"
|
||||
assert config["encoding"]["y"]["title"] == "b"
|
||||
|
||||
assert config["title"] == "Made Up Bar Plot"
|
||||
assert "height" not in config
|
||||
assert "width" not in config
|
||||
|
||||
def test_height_width(self):
|
||||
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
|
||||
assert (output := plot.postprocess(simple))
|
||||
output = output.model_dump()
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
||||
def test_ylim(self):
|
||||
plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100])
|
||||
assert (output := plot.postprocess(simple))
|
||||
output = output.model_dump()
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}
|
||||
|
||||
def test_horizontal(self):
|
||||
output = gr.BarPlot(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
x_title="Variable A",
|
||||
y_title="Variable B",
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=["a", "b"],
|
||||
vertical=False,
|
||||
y_lim=[20, 100],
|
||||
).get_config()
|
||||
assert output["value"]["chart"] == "bar"
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "b"
|
||||
assert config["encoding"]["x"]["scale"] == {"domain": [20, 100]}
|
||||
assert config["encoding"]["x"]["title"] == "Variable B"
|
||||
|
||||
assert config["encoding"]["y"]["field"] == "a"
|
||||
assert config["encoding"]["y"]["title"] == "Variable A"
|
||||
|
||||
def test_barplot_accepts_fn_as_value(self):
|
||||
plot = gr.BarPlot(
|
||||
value=lambda: barley.sample(frac=0.1, replace=False),
|
||||
x="year",
|
||||
y="yield",
|
||||
)
|
||||
assert isinstance(plot.value, dict)
|
||||
assert isinstance(plot.value["plot"], str)
|
@ -1,112 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .plot_data import stocks
|
||||
|
||||
|
||||
class TestLinePlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
assert gr.LinePlot().get_config() == {
|
||||
"caption": None,
|
||||
"elem_id": None,
|
||||
"elem_classes": [],
|
||||
"interactive": None,
|
||||
"label": None,
|
||||
"name": "lineplot",
|
||||
"bokeh_version": "3.0.3",
|
||||
"show_actions_button": False,
|
||||
"proxy_url": None,
|
||||
"show_label": False,
|
||||
"container": True,
|
||||
"min_width": 160,
|
||||
"scale": None,
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"x": None,
|
||||
"y": None,
|
||||
"color": None,
|
||||
"stroke_dash": None,
|
||||
"overlay_point": None,
|
||||
"title": None,
|
||||
"tooltip": [],
|
||||
"x_title": None,
|
||||
"y_title": None,
|
||||
"color_legend_title": None,
|
||||
"stroke_dash_legend_title": None,
|
||||
"color_legend_position": None,
|
||||
"stroke_dash_legend_position": None,
|
||||
"height": None,
|
||||
"width": None,
|
||||
"x_lim": None,
|
||||
"y_lim": None,
|
||||
"x_label_angle": None,
|
||||
"y_label_angle": None,
|
||||
"_selectable": False,
|
||||
"key": None,
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
plot = gr.LinePlot(
|
||||
x="date",
|
||||
y="price",
|
||||
tooltip=["symbol", "price"],
|
||||
title="Stock Performance",
|
||||
x_title="Trading Day",
|
||||
)
|
||||
output = plot.postprocess(stocks).model_dump() # type: ignore
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
for layer in config["layer"]:
|
||||
assert layer["mark"]["type"] in ["line", "point"]
|
||||
assert layer["encoding"]["x"]["field"] == "date"
|
||||
assert layer["encoding"]["x"]["title"] == "Trading Day"
|
||||
assert layer["encoding"]["y"]["field"] == "price"
|
||||
|
||||
assert config["title"] == "Stock Performance"
|
||||
assert "height" not in config
|
||||
assert "width" not in config
|
||||
|
||||
def test_height_width(self):
|
||||
plot = gr.LinePlot(x="date", y="price", height=100, width=200)
|
||||
output = plot.postprocess(stocks).model_dump() # type: ignore
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
||||
def test_xlim_ylim(self):
|
||||
plot = gr.LinePlot(x="date", y="price", x_lim=[200, 400], y_lim=[300, 500])
|
||||
output = plot.postprocess(stocks).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
for layer in config["layer"]:
|
||||
assert layer["encoding"]["x"]["scale"] == {"domain": [200, 400]}
|
||||
assert layer["encoding"]["y"]["scale"] == {"domain": [300, 500]}
|
||||
|
||||
def test_color_encoding(self):
|
||||
plot = gr.LinePlot(
|
||||
x="date", y="price", tooltip="symbol", color="symbol", overlay_point=True
|
||||
)
|
||||
output = plot.postprocess(stocks).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
for layer in config["layer"]:
|
||||
assert layer["encoding"]["color"]["field"] == "symbol"
|
||||
assert layer["encoding"]["color"]["scale"] == {
|
||||
"domain": ["MSFT", "AMZN", "IBM", "GOOG", "AAPL"],
|
||||
"range": [0, 1, 2, 3, 4],
|
||||
}
|
||||
assert layer["encoding"]["color"]["type"] == "nominal"
|
||||
if layer["mark"]["type"] == "point":
|
||||
assert layer["encoding"]["opacity"] == {}
|
||||
|
||||
def test_lineplot_accepts_fn_as_value(self):
|
||||
plot = gr.LinePlot(
|
||||
value=lambda: stocks.sample(frac=0.1, replace=False),
|
||||
x="date",
|
||||
y="price",
|
||||
color="symbol",
|
||||
)
|
||||
assert isinstance(plot.value, dict)
|
||||
assert isinstance(plot.value["plot"], str)
|
32
test/components/test_native_plots.py
Normal file
32
test/components/test_native_plots.py
Normal file
@ -0,0 +1,32 @@
|
||||
import gradio as gr
|
||||
|
||||
from .plot_data import barley, simple
|
||||
|
||||
|
||||
class TestNativePlot:
|
||||
def test_plot_recognizes_correct_datatypes(self):
|
||||
plot = gr.BarPlot(
|
||||
value=simple,
|
||||
x="date",
|
||||
y="b",
|
||||
)
|
||||
assert plot.value["datatypes"]["date"] == "temporal"
|
||||
assert plot.value["datatypes"]["b"] == "quantitative"
|
||||
|
||||
plot = gr.BarPlot(
|
||||
value=simple,
|
||||
x="a",
|
||||
y="b",
|
||||
color="c",
|
||||
)
|
||||
assert plot.value["datatypes"]["a"] == "nominal"
|
||||
assert plot.value["datatypes"]["b"] == "quantitative"
|
||||
assert plot.value["datatypes"]["c"] == "quantitative"
|
||||
|
||||
def test_plot_accepts_fn_as_value(self):
|
||||
plot = gr.BarPlot(
|
||||
value=lambda: barley.sample(frac=0.1, replace=False),
|
||||
x="year",
|
||||
y="yield",
|
||||
)
|
||||
assert plot.value["mark"] == "bar"
|
@ -1,168 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .plot_data import cars
|
||||
|
||||
|
||||
class TestScatterPlot:
|
||||
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
|
||||
def test_get_config(self):
|
||||
print(gr.ScatterPlot().get_config())
|
||||
assert gr.ScatterPlot().get_config() == {
|
||||
"caption": None,
|
||||
"elem_id": None,
|
||||
"elem_classes": [],
|
||||
"interactive": None,
|
||||
"label": None,
|
||||
"name": "scatterplot",
|
||||
"bokeh_version": "3.0.3",
|
||||
"show_actions_button": False,
|
||||
"proxy_url": None,
|
||||
"show_label": False,
|
||||
"container": True,
|
||||
"min_width": 160,
|
||||
"scale": None,
|
||||
"value": None,
|
||||
"visible": True,
|
||||
"x": None,
|
||||
"y": None,
|
||||
"color": None,
|
||||
"size": None,
|
||||
"shape": None,
|
||||
"title": None,
|
||||
"tooltip": None,
|
||||
"x_title": None,
|
||||
"y_title": None,
|
||||
"color_legend_title": None,
|
||||
"size_legend_title": None,
|
||||
"shape_legend_title": None,
|
||||
"color_legend_position": None,
|
||||
"size_legend_position": None,
|
||||
"shape_legend_position": None,
|
||||
"height": None,
|
||||
"width": None,
|
||||
"x_lim": None,
|
||||
"y_lim": None,
|
||||
"x_label_angle": None,
|
||||
"y_label_angle": None,
|
||||
"_selectable": False,
|
||||
"key": None,
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
plot = gr.ScatterPlot(
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
tooltip="Name",
|
||||
title="Car Data",
|
||||
x_title="Horse",
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "Horsepower"
|
||||
assert config["encoding"]["x"]["title"] == "Horse"
|
||||
assert config["encoding"]["y"]["field"] == "Miles_per_Gallon"
|
||||
assert config["title"] == "Car Data"
|
||||
assert "height" not in config
|
||||
assert "width" not in config
|
||||
|
||||
def test_no_interactive(self):
|
||||
plot = gr.ScatterPlot(
|
||||
x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert "selection" not in config
|
||||
|
||||
def test_height_width(self):
|
||||
plot = gr.ScatterPlot(
|
||||
x="Horsepower", y="Miles_per_Gallon", height=100, width=200
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
assert sorted(output.keys()) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
||||
def test_xlim_ylim(self):
|
||||
plot = gr.ScatterPlot(
|
||||
x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500]
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["scale"] == {"domain": [200, 400]}
|
||||
assert config["encoding"]["y"]["scale"] == {"domain": [300, 500]}
|
||||
|
||||
def test_color_encoding(self):
|
||||
plot = gr.ScatterPlot(
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
tooltip="Name",
|
||||
title="Car Data",
|
||||
color="Origin",
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["color"]["field"] == "Origin"
|
||||
assert config["encoding"]["color"]["scale"] == {
|
||||
"domain": ["USA", "Europe", "Japan"],
|
||||
"range": [0, 1, 2],
|
||||
}
|
||||
assert config["encoding"]["color"]["type"] == "nominal"
|
||||
|
||||
def test_two_encodings(self):
|
||||
plot = gr.ScatterPlot(
|
||||
show_label=False,
|
||||
title="Two encodings",
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
color="Acceleration",
|
||||
shape="Origin",
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["color"]["field"] == "Acceleration"
|
||||
assert config["encoding"]["color"]["scale"] == {
|
||||
"domain": [cars.Acceleration.min(), cars.Acceleration.max()],
|
||||
"range": [0, 1],
|
||||
}
|
||||
assert config["encoding"]["color"]["type"] == "quantitative"
|
||||
|
||||
assert config["encoding"]["shape"]["field"] == "Origin"
|
||||
assert config["encoding"]["shape"]["type"] == "nominal"
|
||||
|
||||
def test_legend_position(self):
|
||||
plot = gr.ScatterPlot(
|
||||
show_label=False,
|
||||
title="Two encodings",
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
color="Acceleration",
|
||||
color_legend_position="none",
|
||||
color_legend_title="Foo",
|
||||
shape="Origin",
|
||||
shape_legend_position="none",
|
||||
shape_legend_title="Bar",
|
||||
size="Acceleration",
|
||||
size_legend_title="Accel",
|
||||
size_legend_position="none",
|
||||
)
|
||||
output = plot.postprocess(cars).model_dump() # type: ignore
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["color"]["legend"] is None
|
||||
assert config["encoding"]["shape"]["legend"] is None
|
||||
assert config["encoding"]["size"]["legend"] is None
|
||||
|
||||
def test_scatterplot_accepts_fn_as_value(self):
|
||||
plot = gr.ScatterPlot(
|
||||
value=lambda: cars.sample(frac=0.1, replace=False),
|
||||
x="Horsepower",
|
||||
y="Miles_per_Gallon",
|
||||
color="Origin",
|
||||
)
|
||||
assert isinstance(plot.value, dict)
|
||||
assert isinstance(plot.value["plot"], str)
|
Loading…
x
Reference in New Issue
Block a user