Refactor plots to drop altair and use vega.js directly (#8807)

* changes

* add changeset

* changes

* changes

* changes

* add changeset

* changes

* add changeset

* changes

* add changeset

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/blocks.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update gradio/components/native_plot.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2024-07-22 09:52:48 -07:00 committed by GitHub
parent 914b1935de
commit a238af4d68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 1204 additions and 688 deletions

View File

@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/datetime": minor
"@gradio/nativeplot": minor
"gradio": minor
---
feat:Refactor plots to drop `altair` and use `vega.js` directly

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks(fill_width=True) as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -10,7 +10,7 @@ def xray_model(diseases, img):
def ct_model(diseases, img):
return [{disease: 0.1 for disease in diseases}]
with gr.Blocks() as demo:
with gr.Blocks(fill_width=True) as demo:
gr.Markdown(
"""
# Detect Disease From Scan

View File

@ -1,111 +1,77 @@
import gradio as gr
import pandas as pd
import numpy as np
from data import temp_sensor_data, food_rating_data
from vega_datasets import data
with gr.Blocks() as bar_plots:
with gr.Row():
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
end = gr.DateTime("2021-01-05 00:00:00", label="End")
apply_btn = gr.Button("Apply", scale=0)
with gr.Row():
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
barley = data.barley()
simple = pd.DataFrame({
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
})
def bar_plot_fn(display):
if display == "simple":
return gr.BarPlot(
simple,
x="a",
y="b",
color=None,
group=None,
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
y_lim=[20, 100],
x_title=None,
y_title=None,
vertical=True,
)
elif display == "stacked":
return gr.BarPlot(
barley,
x="variety",
y="yield",
color="site",
group=None,
title="Barley Yield Data",
tooltip=['variety', 'site'],
y_lim=None,
x_title=None,
y_title=None,
vertical=True,
)
elif display == "grouped":
return gr.BarPlot(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
tooltip=["yield", "site", "year"],
y_lim=None,
x_title=None,
y_title=None,
vertical=True,
)
elif display == "simple-horizontal":
return gr.BarPlot(
simple,
x="a",
y="b",
color=None,
group=None,
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
y_lim=[20, 100],
x_title="Variable A",
y_title="Variable B",
vertical=False,
)
elif display == "stacked-horizontal":
return gr.BarPlot(
barley,
x="variety",
y="yield",
color="site",
group=None,
title="Barley Yield Data",
tooltip=['variety', 'site'],
y_lim=None,
x_title=None,
y_title=None,
vertical=False,
)
elif display == "grouped-horizontal":
return gr.BarPlot(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="",
tooltip=["yield", "site", "year"],
y_lim=None,
x_title=None,
y_title=None,
vertical=False
)
with gr.Blocks() as bar_plot:
display = gr.Dropdown(
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
value="simple",
label="Type of Bar Plot"
temp_by_time = gr.BarPlot(
temp_sensor_data,
x="time",
y="temperature",
)
plot = gr.BarPlot(show_label=False)
display.change(bar_plot_fn, inputs=display, outputs=plot)
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)
temp_by_time_location = gr.BarPlot(
temp_sensor_data,
x="time",
y="temperature",
color="location",
)
time_graphs = [temp_by_time, temp_by_time_location]
group_by.change(
lambda group: [gr.BarPlot(x_bin=None if group == "None" else group)] * len(time_graphs),
group_by,
time_graphs
)
aggregate.change(
lambda aggregate: [gr.BarPlot(y_aggregate=aggregate)] * len(time_graphs),
aggregate,
time_graphs
)
def rescale(select: gr.SelectData):
return select.index
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
for trigger in [apply_btn.click, rescale_evt.then]:
trigger(
lambda start, end: [gr.BarPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
)
with gr.Row():
price_by_cuisine = gr.BarPlot(
food_rating_data,
x="cuisine",
y="price",
)
with gr.Column(scale=0):
gr.Button("Sort $ > $$$").click(lambda: gr.BarPlot(sort="y"), None, price_by_cuisine)
gr.Button("Sort $$$ > $").click(lambda: gr.BarPlot(sort="-y"), None, price_by_cuisine)
gr.Button("Sort A > Z").click(lambda: gr.BarPlot(sort=["Chinese", "Italian", "Mexican"]), None, price_by_cuisine)
with gr.Row():
price_by_rating = gr.BarPlot(
food_rating_data,
x="rating",
y="price",
x_bin=1,
)
price_by_rating_color = gr.BarPlot(
food_rating_data,
x="rating",
y="price",
color="cuisine",
x_bin=1,
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
)
if __name__ == "__main__":
bar_plot.launch()
bar_plots.launch()

20
demo/native_plots/data.py Normal file
View File

@ -0,0 +1,20 @@
import pandas as pd
from random import randint, choice, random
temp_sensor_data = pd.DataFrame(
{
"time": pd.date_range("2021-01-01", end="2021-01-05", periods=200),
"temperature": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
"humidity": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)],
"location": ["indoor", "outdoor"] * 100,
}
)
food_rating_data = pd.DataFrame(
{
"cuisine": [["Italian", "Mexican", "Chinese"][i % 3] for i in range(100)],
"rating": [random() * 4 + 0.5 * (i % 3) for i in range(100)],
"price": [randint(10, 50) + 4 * (i % 3) for i in range(100)],
"wait": [random() for i in range(100)],
}
)

View File

@ -1,82 +1,69 @@
import gradio as gr
from vega_datasets import data
import numpy as np
from data import temp_sensor_data, food_rating_data
stocks = data.stocks()
gapminder = data.gapminder()
gapminder = gapminder.loc[
gapminder.country.isin(["Argentina", "Australia", "Afghanistan"])
]
climate = data.climate()
seattle_weather = data.seattle_weather()
with gr.Blocks() as line_plots:
with gr.Row():
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
end = gr.DateTime("2021-01-05 00:00:00", label="End")
apply_btn = gr.Button("Apply", scale=0)
with gr.Row():
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
def line_plot_fn(dataset):
if dataset == "stocks":
return gr.LinePlot(
stocks,
x="date",
y="price",
color="symbol",
x_lim=None,
y_lim=None,
stroke_dash=None,
tooltip=['date', 'price', 'symbol'],
overlay_point=False,
title="Stock Prices",
stroke_dash_legend_title=None,
)
elif dataset == "climate":
return gr.LinePlot(
climate,
x="DATE",
y="HLY-TEMP-NORMAL",
color=None,
x_lim=None,
y_lim=[250, 500],
stroke_dash=None,
tooltip=['DATE', 'HLY-TEMP-NORMAL'],
overlay_point=False,
title="Climate",
stroke_dash_legend_title=None,
)
elif dataset == "seattle_weather":
return gr.LinePlot(
seattle_weather,
x="date",
y="temp_min",
color=None,
x_lim=None,
y_lim=None,
stroke_dash=None,
tooltip=["weather", "date"],
overlay_point=True,
title="Seattle Weather",
stroke_dash_legend_title=None,
)
elif dataset == "gapminder":
return gr.LinePlot(
gapminder,
x="year",
y="life_expect",
color="country",
x_lim=[1950, 2010],
y_lim=None,
stroke_dash="cluster",
tooltip=['country', 'life_expect'],
overlay_point=False,
title="Life expectancy for countries",
)
with gr.Blocks() as line_plot:
dataset = gr.Dropdown(
choices=["stocks", "climate", "seattle_weather", "gapminder"],
value="stocks",
temp_by_time = gr.LinePlot(
temp_sensor_data,
x="time",
y="temperature",
)
plot = gr.LinePlot()
dataset.change(line_plot_fn, inputs=dataset, outputs=plot)
line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)
temp_by_time_location = gr.LinePlot(
temp_sensor_data,
x="time",
y="temperature",
color="location",
)
time_graphs = [temp_by_time, temp_by_time_location]
group_by.change(
lambda group: [gr.LinePlot(x_bin=None if group == "None" else group)] * len(time_graphs),
group_by,
time_graphs
)
aggregate.change(
lambda aggregate: [gr.LinePlot(y_aggregate=aggregate)] * len(time_graphs),
aggregate,
time_graphs
)
def rescale(select: gr.SelectData):
return select.index
rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
for trigger in [apply_btn.click, rescale_evt.then]:
trigger(
lambda start, end: [gr.LinePlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
)
price_by_cuisine = gr.LinePlot(
food_rating_data,
x="cuisine",
y="price",
)
with gr.Row():
price_by_rating = gr.LinePlot(
food_rating_data,
x="rating",
y="price",
)
price_by_rating_color = gr.LinePlot(
food_rating_data,
x="rating",
y="price",
color="cuisine",
color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
)
if __name__ == "__main__":
line_plot.launch()
line_plots.launch()

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/data.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plots\n", "from line_plot_demo import line_plots\n", "from bar_plot_demo import bar_plots\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plots.render()\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plots.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plots.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -1,18 +1,18 @@
import gradio as gr
from scatter_plot_demo import scatter_plot
from line_plot_demo import line_plot
from bar_plot_demo import bar_plot
from scatter_plot_demo import scatter_plots
from line_plot_demo import line_plots
from bar_plot_demo import bar_plots
with gr.Blocks() as demo:
with gr.Tabs():
with gr.TabItem("Scatter Plot"):
scatter_plot.render()
with gr.TabItem("Line Plot"):
line_plot.render()
line_plots.render()
with gr.TabItem("Scatter Plot"):
scatter_plots.render()
with gr.TabItem("Bar Plot"):
bar_plot.render()
bar_plots.render()
if __name__ == "__main__":
demo.launch()

View File

@ -1,46 +1,71 @@
import gradio as gr
import numpy as np
from data import temp_sensor_data, food_rating_data
from vega_datasets import data
with gr.Blocks() as scatter_plots:
with gr.Row():
start = gr.DateTime("2021-01-01 00:00:00", label="Start")
end = gr.DateTime("2021-01-05 00:00:00", label="End")
apply_btn = gr.Button("Apply", scale=0)
with gr.Row():
group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by")
aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation")
cars = data.cars()
iris = data.iris()
temp_by_time = gr.ScatterPlot(
temp_sensor_data,
x="time",
y="temperature",
)
temp_by_time_location = gr.ScatterPlot(
temp_sensor_data,
x="time",
y="temperature",
color="location",
)
time_graphs = [temp_by_time, temp_by_time_location]
group_by.change(
lambda group: [gr.ScatterPlot(x_bin=None if group == "None" else group)] * len(time_graphs),
group_by,
time_graphs
)
aggregate.change(
lambda aggregate: [gr.ScatterPlot(y_aggregate=aggregate)] * len(time_graphs),
aggregate,
time_graphs
)
def scatter_plot_fn(dataset):
if dataset == "iris":
return gr.ScatterPlot(
value=iris,
x="petalWidth",
y="petalLength",
color=None,
title="Iris Dataset",
x_title="Petal Width",
y_title="Petal Length",
tooltip=["petalWidth", "petalLength", "species"],
caption="",
height=600,
width=600,
# def rescale(select: gr.SelectData):
# return select.index
# rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end])
# for trigger in [apply_btn.click, rescale_evt.then]:
# trigger(
# lambda start, end: [gr.ScatterPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs
# )
price_by_cuisine = gr.ScatterPlot(
food_rating_data,
x="cuisine",
y="price",
)
with gr.Row():
price_by_rating = gr.ScatterPlot(
food_rating_data,
x="rating",
y="price",
color="wait",
show_actions_button=True,
)
else:
return gr.ScatterPlot(
value=cars,
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
tooltip="Name",
title="Car Data",
y_title="Miles per Gallon",
caption="MPG vs Horsepower of various cars",
height=None,
width=None,
price_by_rating_color = gr.ScatterPlot(
food_rating_data,
x="rating",
y="price",
color="cuisine",
# color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"},
)
with gr.Blocks() as scatter_plot:
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
plot = gr.ScatterPlot(show_label=False)
dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)
if __name__ == "__main__":
scatter_plot.launch()
scatter_plots.launch()

View File

@ -942,6 +942,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: str | None = None,
head: str | None = None,
fill_height: bool = False,
fill_width: bool = False,
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
@ -955,6 +956,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: Custom js as a string or path to a js file. The custom js should be in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside <script> tags.
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, multiple scripts, stylesheets, etc. to the page.
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width. Only applies if this is the outermost `Blocks` in your Gradio app.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
self.limiter = None
@ -988,6 +990,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.show_error = True
self.head = head
self.fill_height = fill_height
self.fill_width = fill_width
self.delete_cache = delete_cache
if css is not None and os.path.exists(css):
with open(css, encoding="utf-8") as css_file:
@ -2036,6 +2039,7 @@ Received outputs:
),
},
"fill_height": self.fill_height,
"fill_width": self.fill_width,
"theme_hash": self.theme_hash,
}
config.update(self.default_config.get_config()) # type: ignore

View File

@ -86,6 +86,7 @@ class ChatInterface(Blocks):
fill_height: bool = True,
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "minimal",
fill_width: bool = False,
):
"""
Parameters:
@ -116,6 +117,7 @@ class ChatInterface(Blocks):
fill_height: If True, the chat interface will expand to the height of window.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -126,6 +128,7 @@ class ChatInterface(Blocks):
js=js,
head=head,
fill_height=fill_height,
fill_width=fill_width,
delete_cache=delete_cache,
)
self.type: Literal["messages", "tuples"] = type

View File

@ -87,7 +87,7 @@ OVERRIDES = {
"Plot": ComponentFiles(template="Plot", demo_code=static_only_demo_code),
"BarPlot": ComponentFiles(
template="BarPlot",
python_file_name="bar_plot.py",
python_file_name="native_plot.py",
js_dir="plot",
demo_code=static_only_demo_code,
),
@ -121,7 +121,7 @@ OVERRIDES = {
),
"LinePlot": ComponentFiles(
template="LinePlot",
python_file_name="line_plot.py",
python_file_name="native_plot.py",
js_dir="plot",
demo_code=static_only_demo_code,
),
@ -139,7 +139,7 @@ OVERRIDES = {
),
"ScatterPlot": ComponentFiles(
template="ScatterPlot",
python_file_name="scatter_plot.py",
python_file_name="native_plot.py",
js_dir="plot",
demo_code=static_only_demo_code,
),

View File

@ -1,6 +1,5 @@
from gradio.components.annotated_image import AnnotatedImage
from gradio.components.audio import Audio
from gradio.components.bar_plot import BarPlot
from gradio.components.base import (
Component,
FormComponent,
@ -33,17 +32,16 @@ from gradio.components.image import Image
from gradio.components.image_editor import ImageEditor
from gradio.components.json_component import JSON
from gradio.components.label import Label
from gradio.components.line_plot import LinePlot
from gradio.components.login_button import LoginButton
from gradio.components.logout_button import LogoutButton
from gradio.components.markdown import Markdown
from gradio.components.model3d import Model3D
from gradio.components.multimodal_textbox import MultimodalTextbox
from gradio.components.native_plot import BarPlot, LinePlot, ScatterPlot
from gradio.components.number import Number
from gradio.components.paramviewer import ParamViewer
from gradio.components.plot import Plot
from gradio.components.radio import Radio
from gradio.components.scatter_plot import ScatterPlot
from gradio.components.slider import Slider
from gradio.components.state import State
from gradio.components.textbox import Textbox

View File

@ -0,0 +1,263 @@
from __future__ import annotations
import json
import warnings
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
List,
Literal,
Sequence,
)
import pandas as pd
from gradio_client.documentation import document
from gradio.components.base import Component
from gradio.data_classes import GradioModel
from gradio.events import Events
if TYPE_CHECKING:
from gradio.components import Timer
class PlotData(GradioModel):
columns: List[str]
data: List[List[Any]]
datatypes: Dict[str, Literal["quantitative", "nominal", "temporal"]]
mark: str
class NativePlot(Component):
"""
Creates a native Gradio plot component to display data from a pandas DataFrame. Supports interactivity and updates.
Demos: native_plots
"""
EVENTS = [Events.select]
def __init__(
self,
value: pd.DataFrame | Callable | None = None,
x: str | None = None,
y: str | None = None,
*,
color: str | None = None,
title: str | None = None,
x_title: str | None = None,
y_title: str | None = None,
color_title: str | None = None,
x_bin: str | float | None = None,
y_aggregate: Literal["sum", "mean", "median", "min", "max", "count"]
| None = None,
color_map: dict[str, str] | None = None,
x_lim: list[float] | None = None,
y_lim: list[float] | None = None,
caption: str | None = None,
sort: Literal["x", "y", "-x", "-y"] | list[str] | None = None,
height: int | None = None,
label: str | None = None,
show_label: bool | None = None,
container: bool = True,
scale: int | None = None,
min_width: int = 160,
every: Timer | float | None = None,
inputs: Component | Sequence[Component] | AbstractSet[Component] | None = None,
visible: bool = True,
elem_id: str | None = None,
elem_classes: list[str] | str | None = None,
render: bool = True,
key: int | str | None = None,
**kwargs,
):
"""
Parameters:
value: The pandas dataframe containing the data to display in the plot.
x: Column corresponding to the x axis. Column can be numeric, datetime, or string/category.
y: Column corresponding to the y axis. Column must be numeric.
color: Column corresponding to series, visualized by color. Column must be string/category.
title: The title to display on top of the chart.
x_title: The title given to the x axis. By default, uses the value of the x parameter.
y_title: The title given to the y axis. By default, uses the value of the y parameter.
color_title: The title given to the color legend. By default, uses the value of color parameter.
x_bin: Grouping used to cluster x values. If x column is numeric, should be number to bin the x values. If x column is datetime, should be string such as "1h", "15m", "10s", using "s", "m", "h", "d" suffixes.
y_aggregate: Aggregation function used to aggregate y values, used if x_bin is provided or x is a string/category. Must be one of "sum", "mean", "median", "min", "max".
color_map: Mapping of series to color names or codes. For example, {"success": "green", "fail": "#FF8888"}.
height: The height of the plot in pixels.
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max]. If x column is datetime type, x_lim should be timestamps.
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
sort: The sorting order of the x values, if x column is type string/category. Can be "x", "y", "-x", "-y", or list of strings that represent the order of the categories.
height: The height of the plot in pixels.
label: The (optional) label to display on the top left corner of the plot.
show_label: Whether the label should be displayed.
container: If True, will place the component in a container - providing some extra padding around the border.
scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
every: Continously calls `value` to recalculate it if `value` is a function (has no effect otherwise). Can provide a Timer whose tick resets `value`, or a float that provides the regular interval for the reset Timer.
inputs: Components that are used as inputs to calculate `value` if `value` is a function (has no effect otherwise). `value` is recalculated any time the inputs change.
visible: Whether the plot should be visible.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
render: If False, component will not render be rendered in the Blocks context. Should be used if the intention is to assign event listeners now but render the component later.
key: if assigned, will be used to assume identity across a re-render. Components that have the same key across a re-render will have their value preserved.
"""
self.x = x
self.y = y
self.color = color
self.title = title
self.x_title = x_title
self.y_title = y_title
self.color_title = color_title
self.x_bin = x_bin
self.y_aggregate = y_aggregate
self.color_map = color_map
self.x_lim = x_lim
self.y_lim = y_lim
self.caption = caption
self.sort = sort
self.height = height
if label is None and show_label is None:
show_label = False
super().__init__(
value=value,
label=label,
show_label=show_label,
container=container,
scale=scale,
min_width=min_width,
visible=visible,
elem_id=elem_id,
elem_classes=elem_classes,
render=render,
key=key,
every=every,
inputs=inputs,
)
for key, val in kwargs.items():
if key == "color_legend_title":
self.color_title = val
if key in [
"stroke_dash",
"overlay_point",
"tooltip",
"x_label_angle",
"y_label_angle",
"interactive",
"show_actions_button",
"color_legend_title",
"width",
]:
warnings.warn(
f"Argument '{key}' has been deprecated.", DeprecationWarning
)
def get_block_name(self) -> str:
return "nativeplot"
def get_mark(self) -> str:
return "native"
def preprocess(self, payload: PlotData | None) -> PlotData | None:
"""
Parameters:
payload: The data to display in a line plot.
Returns:
The data to display in a line plot.
"""
return payload
def postprocess(self, value: pd.DataFrame | dict | None) -> PlotData | None:
"""
Parameters:
value: Expects a pandas DataFrame containing the data to display in the line plot. The DataFrame should contain at least two columns, one for the x-axis (corresponding to this component's `x` argument) and one for the y-axis (corresponding to `y`).
Returns:
The data to display in a line plot, in the form of an AltairPlotData dataclass, which includes the plot information as a JSON string, as well as the type of plot (in this case, "line").
"""
# if None or update
if value is None or isinstance(value, dict):
return value
def get_simplified_type(dtype):
if pd.api.types.is_numeric_dtype(dtype):
return "quantitative"
elif pd.api.types.is_string_dtype(
dtype
) or pd.api.types.is_categorical_dtype(dtype):
return "nominal"
elif pd.api.types.is_datetime64_any_dtype(dtype):
return "temporal"
else:
raise ValueError(f"Unsupported data type: {dtype}")
split_json = json.loads(value.to_json(orient="split", date_unit="ms"))
datatypes = {
col: get_simplified_type(value[col].dtype) for col in value.columns
}
return PlotData(
columns=split_json["columns"],
data=split_json["data"],
datatypes=datatypes,
mark=self.get_mark(),
)
def example_payload(self) -> Any:
return None
def example_value(self) -> Any:
import pandas as pd
return pd.DataFrame({self.x: [1, 2, 3], self.y: [4, 5, 6]})
def api_info(self) -> dict[str, Any]:
return {"type": {}, "description": "any valid json"}
@document()
class BarPlot(NativePlot):
"""
Creates a bar plot component to display data from a pandas DataFrame.
Demos: native_plots
"""
def get_block_name(self) -> str:
return "nativeplot"
def get_mark(self) -> str:
return "bar"
@document()
class LinePlot(NativePlot):
"""
Creates a line plot component to display data from a pandas DataFrame.
Demos: native_plots
"""
def get_block_name(self) -> str:
return "nativeplot"
def get_mark(self) -> str:
return "line"
@document()
class ScatterPlot(NativePlot):
"""
Creates a scatter plot component to display data from a pandas DataFrame.
Demos: native_plots
"""
def get_block_name(self) -> str:
return "nativeplot"
def get_mark(self) -> str:
return "point"

View File

@ -315,6 +315,7 @@ class BlocksConfigDict(TypedDict):
protocol: Literal["ws", "sse", "sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
body_css: BodyCSS
fill_height: bool
fill_width: bool
theme_hash: str
layout: NotRequired[Layout]
dependencies: NotRequired[list[dict[str, Any]]]

View File

@ -133,6 +133,7 @@ class Interface(Blocks):
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "full",
example_labels: list[str] | None = None,
fill_width: bool = False,
**kwargs,
):
"""
@ -170,6 +171,7 @@ class Interface(Blocks):
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running. Has no effect if the interface is `live`.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -180,6 +182,7 @@ class Interface(Blocks):
js=js,
head=head,
delete_cache=delete_cache,
fill_width=fill_width,
**kwargs,
)
self.api_name: str | Literal[False] | None = api_name

View File

@ -8,13 +8,16 @@ import gradio as gr
data = {"data": {}}
with gr.Blocks() as demo:
gr.Markdown("# Monitoring Dashboard")
timer = gr.Timer(5)
with gr.Row():
selected_function = gr.Dropdown(
start = gr.DateTime("now - 24h", label="Start Time")
end = gr.DateTime("now", label="End Time")
selected_fn = gr.Dropdown(
["All"],
value="All",
label="Endpoint",
info="Select the function to see analytics for, or 'All' for aggregate.",
scale=2,
)
demo.load(
lambda: gr.Dropdown(
@ -22,70 +25,78 @@ with gr.Blocks() as demo:
+ list({row["function"] for row in data["data"].values()}) # type: ignore
),
None,
selected_function,
)
timespan = gr.Dropdown(
["All Time", "24 hours", "1 hours", "10 minutes"],
value="All Time",
label="Timespan",
info="Duration to see data for.",
selected_fn,
)
with gr.Group():
with gr.Row():
unique_users = gr.Label(label="Unique Users")
unique_requests = gr.Label(label="Unique Requests")
total_requests = gr.Label(label="Total Requests")
process_time = gr.Label(label="Avg Process Time")
plot = gr.BarPlot(
x="time",
y="count",
y="function",
color="status",
title="Requests over Time",
y_title="Requests",
width=600,
x_bin="1m",
y_aggregate="count",
color_map={
"success": "#22c55e",
"failure": "#ef4444",
"pending": "#eab308",
"queued": "#3b82f6",
},
)
@gr.on(
[demo.load, selected_function.change, timespan.change],
inputs=[selected_function, timespan],
outputs=[unique_users, unique_requests, process_time, plot],
[demo.load, timer.tick, start.change, end.change, selected_fn.change],
inputs=[start, end, selected_fn],
outputs=[plot, unique_users, total_requests, process_time],
)
def load_dfs(function, timespan):
df = pd.DataFrame(data["data"].values())
if df.empty:
return 0, 0, 0, gr.skip()
def gen_plot(start, end, selected_fn):
df = pd.DataFrame(list(data["data"].values()))
if selected_fn != "All":
df = df[df["function"] == selected_fn]
df = df[(df["time"] >= start) & (df["time"] <= end)]
df["time"] = pd.to_datetime(df["time"], unit="s")
df_filtered = df if function == "All" else df[df["function"] == function]
if timespan != "All Time":
df_filtered = df_filtered[
df_filtered["time"] > pd.Timestamp.now() - pd.Timedelta(timespan) # type: ignore
]
df_filtered["time"] = df_filtered["time"].dt.floor("min") # type: ignore
plot = df_filtered.groupby(["time", "status"]).size().reset_index(name="count") # type: ignore
mean_process_time_for_success = df_filtered[df_filtered["status"] == "success"][
"process_time"
].mean()
unique_users = len(df["session_hash"].unique())
total_requests = len(df)
process_time = round(df["process_time"].mean(), 2)
duration = end - start
x_bin = (
"1h"
if duration >= 60 * 60 * 24
else "15m"
if duration >= 60 * 60 * 3
else "1m"
)
df = df.drop(columns=["session_hash"])
return (
df_filtered["session_hash"].nunique(), # type: ignore
df_filtered.shape[0],
round(mean_process_time_for_success, 2),
plot,
gr.BarPlot(value=df, x_bin=x_bin, x_lim=[start, end]),
unique_users,
total_requests,
process_time,
)
if __name__ == "__main__":
data["data"] = {
random.randint(0, 1000000): {
"time": time.time() - random.randint(0, 60 * 60 * 24 * 3),
data["data"] = {}
for _ in range(random.randint(300, 500)):
timedelta = random.randint(0, 60 * 60 * 24 * 3)
data["data"][random.randint(1, 100000)] = {
"time": time.time() - timedelta,
"status": random.choice(
["success", "success", "failure", "pending", "queued"]
["success", "success", "failure"]
if timedelta > 30 * 60
else ["queued", "pending"]
),
"function": random.choice(["predict", "chat", "chat"]),
"process_time": random.randint(0, 10),
"session_hash": str(random.randint(0, 4)),
}
for r in range(random.randint(100, 200))
}
demo.launch()

View File

@ -563,7 +563,8 @@ def get_all_components() -> list[type[Component] | type[BlockContext]]:
return [
c
for c in subclasses
if c.__name__ not in ["ChatInterface", "Interface", "Blocks", "TabbedInterface"]
if c.__name__
not in ["ChatInterface", "Interface", "Blocks", "TabbedInterface", "NativePlot"]
]

View File

@ -57,6 +57,7 @@
"@gradio/markdown": "workspace:^",
"@gradio/model3d": "workspace:^",
"@gradio/multimodaltextbox": "workspace:^",
"@gradio/nativeplot": "workspace:^",
"@gradio/number": "workspace:^",
"@gradio/paramviewer": "workspace:^",
"@gradio/plot": "workspace:^",

View File

@ -4,6 +4,7 @@
export let wrapper: HTMLDivElement;
export let version: string;
export let initial_height: string;
export let fill_width: boolean;
export let is_embed: boolean;
export let space: string | null;
@ -15,6 +16,7 @@
<div
bind:this={wrapper}
class:app={!display && !is_embed}
class:fill_width
class:embed-container={display}
class:with-info={info}
class="gradio-container gradio-container-{version}"
@ -87,27 +89,27 @@
}
@media (--screen-sm) {
.app {
.app:not(.fill_width) {
max-width: 640px;
}
}
@media (--screen-md) {
.app {
.app:not(.fill_width) {
max-width: 768px;
}
}
@media (--screen-lg) {
.app {
.app:not(.fill_width) {
max-width: 1024px;
}
}
@media (--screen-xl) {
.app {
.app:not(.fill_width) {
max-width: 1280px;
}
}
@media (--screen-xxl) {
.app {
.app:not(.fill_width) {
max-width: 1536px;
}
}

View File

@ -29,6 +29,7 @@
path: string;
app_id?: string;
fill_height?: boolean;
fill_width?: boolean;
theme_hash?: number;
username: string | null;
}
@ -412,6 +413,7 @@
{initial_height}
{space}
loaded={loader_status === "complete"}
fill_width={config?.fill_width || false}
bind:wrapper
>
{#if (loader_status === "pending" || loader_status === "error") && !(config && config?.auth_required)}

View File

@ -25,6 +25,8 @@
export let include_time = true;
$: if (value !== old_value) {
old_value = value;
entered_value = value;
datevalue = value;
gradio.dispatch("change");
}
@ -49,7 +51,7 @@
let entered_value = value;
let datetime: HTMLInputElement;
let datevalue = "";
let datevalue = value;
const date_is_valid_format = (date: string): boolean => {
if (date === "") return false;
@ -156,7 +158,6 @@
flex-shrink: 1;
display: flex;
position: relative;
box-shadow: var(--input-shadow);
background: var(--input-background-fill);
}
.timebox :global(svg) {
@ -175,6 +176,7 @@
border-right: none;
border-top-left-radius: var(--input-radius);
border-bottom-left-radius: var(--input-radius);
box-shadow: var(--input-shadow);
}
.time.invalid {
color: var(--body-text-color-subdued);
@ -193,6 +195,13 @@
border-top-right-radius: var(--input-radius);
border-bottom-right-radius: var(--input-radius);
padding: var(--size-2);
border: var(--input-border-width) solid var(--input-border-color);
}
.calendar:hover {
background: var(--button-secondary-background-fill-hover);
}
.calendar:active {
box-shadow: var(--button-shadow-active);
}
.datetime {
width: 0px;

View File

@ -0,0 +1 @@
# @gradio/nativeplot

View File

@ -0,0 +1,11 @@
<script lang="ts">
export let title: string | null;
export let x: string;
export let y: string;
</script>
{#if title}
{title}
{:else}
{x} x {y}
{/if}

505
js/nativeplot/Index.svelte Normal file
View File

@ -0,0 +1,505 @@
<script lang="ts">
import type { Gradio, SelectData } from "@gradio/utils";
import { BlockTitle } from "@gradio/atoms";
import { Block } from "@gradio/atoms";
import { StatusTracker } from "@gradio/statustracker";
import type { LoadingStatus } from "@gradio/statustracker";
import { onMount } from "svelte";
import type { TopLevelSpec as Spec } from "vega-lite";
import vegaEmbed from "vega-embed";
import type { View } from "vega";
import { LineChart as LabelIcon } from "@gradio/icons";
import { Empty } from "@gradio/atoms";
interface PlotData {
columns: string[];
data: [string | number][];
datatypes: Record<string, "quantitative" | "temporal" | "nominal">;
mark: "line" | "point" | "bar";
}
export let value: PlotData | null;
export let x: string;
export let y: string;
export let color: string | null = null;
$: unique_colors =
color && value && value.datatypes[color] === "nominal"
? Array.from(new Set(_data.map((d) => d[color])))
: [];
export let title: string | null = null;
export let x_title: string | null = null;
export let y_title: string | null = null;
export let color_title: string | null = null;
export let x_bin: string | number | null = null;
export let y_aggregate:
| "sum"
| "mean"
| "median"
| "min"
| "max"
| undefined = undefined;
export let color_map: Record<string, string> | null = null;
export let x_lim: [number, number] | null = null;
export let y_lim: [number, number] | null = null;
export let caption: string | null = null;
export let sort: "x" | "y" | "-x" | "-y" | string[] | null = null;
function reformat_sort(
_sort: typeof sort
):
| string
| "ascending"
| "descending"
| { field: string; order: "ascending" | "descending" }
| string[]
| undefined {
if (_sort === "x") {
return "ascending";
} else if (_sort === "-x") {
return "descending";
} else if (_sort === "y") {
return { field: y, order: "ascending" };
} else if (_sort === "-y") {
return { field: y, order: "descending" };
} else if (_sort === null) {
return undefined;
} else if (Array.isArray(_sort)) {
return _sort;
}
}
$: _sort = reformat_sort(sort);
export let _selectable = false;
export let target: HTMLDivElement;
let _data: {
[x: string]: string | number;
}[];
export let gradio: Gradio<{
select: SelectData;
clear_status: LoadingStatus;
}>;
$: x_temporal = value && value.datatypes[x] === "temporal";
$: _x_lim = x_lim && x_temporal ? [x_lim[0] * 1000, x_lim[1] * 1000] : x_lim;
let _x_bin: number | undefined;
let mouse_down_on_chart = false;
const SUFFIX_DURATION: Record<string, number> = {
s: 1,
m: 60,
h: 60 * 60,
d: 24 * 60 * 60
};
$: _x_bin = x_bin
? typeof x_bin === "string"
? 1000 *
parseInt(x_bin.substring(0, x_bin.length - 1)) *
SUFFIX_DURATION[x_bin[x_bin.length - 1]]
: x_bin
: undefined;
let _y_aggregate: typeof y_aggregate;
let aggregating: boolean;
$: {
if (value) {
if (value.mark === "point") {
aggregating = _x_bin !== undefined;
_y_aggregate = y_aggregate || aggregating ? "sum" : undefined;
} else {
aggregating = _x_bin !== undefined || value.datatypes[x] === "nominal";
_y_aggregate = y_aggregate ? y_aggregate : "sum";
}
}
}
function reformat_data(data: PlotData): {
[x: string]: string | number;
}[] {
let x_index = data.columns.indexOf(x);
let y_index = data.columns.indexOf(y);
let color_index = color ? data.columns.indexOf(color) : null;
return data.data.map((row) => {
const obj = {
[x]: row[x_index],
[y]: row[y_index]
};
if (color && color_index !== null) {
obj[color] = row[color_index];
}
return obj;
});
}
$: _data = value ? reformat_data(value) : [];
let chart_element: HTMLDivElement;
let computed_style = window.getComputedStyle(target);
let view: View;
let mounted = false;
let old_width: number;
function load_chart(): void {
if (view) {
view.finalize();
}
if (!value) return;
old_width = chart_element.offsetWidth;
const spec = create_vega_lite_spec();
if (!spec) return;
let resizeObserver = new ResizeObserver(() => {
if (
old_width === 0 &&
chart_element.offsetWidth !== 0 &&
value.datatypes[x] === "nominal"
) {
// a bug where when a nominal chart is first loaded, the width is 0, it doesn't resize
load_chart();
} else {
view.signal("width", chart_element.offsetWidth).run();
}
});
vegaEmbed(chart_element, spec, { actions: false }).then(function (result) {
view = result.view;
resizeObserver.observe(chart_element);
var debounceTimeout: NodeJS.Timeout;
if (_selectable) {
view.addSignalListener("brush", function (_, value) {
if (Object.keys(value).length === 0) return;
clearTimeout(debounceTimeout);
let range: [number, number] = value[Object.keys(value)[0]];
if (x_temporal) {
range = [range[0] / 1000, range[1] / 1000];
}
let callback = (): void => {
gradio.dispatch("select", {
value: range,
index: range,
selected: true
});
};
if (mouse_down_on_chart) {
release_callback = callback;
} else {
debounceTimeout = setTimeout(function () {
gradio.dispatch("select", {
value: range,
index: range,
selected: true
});
}, 250);
}
});
}
});
}
let release_callback: (() => void) | null = null;
onMount(() => {
mounted = true;
chart_element.addEventListener("mousedown", () => {
mouse_down_on_chart = true;
});
chart_element.addEventListener("mouseup", () => {
mouse_down_on_chart = false;
if (release_callback) {
release_callback();
release_callback = null;
}
});
});
$: title,
x_title,
y_title,
color_title,
x,
y,
color,
x_bin,
_y_aggregate,
color_map,
x_lim,
y_lim,
caption,
sort,
mounted && load_chart();
function create_vega_lite_spec(): Spec | null {
if (!value) return null;
let accent_color = computed_style.getPropertyValue("--color-accent");
let body_text_color = computed_style.getPropertyValue("--body-text-color");
let borderColorPrimary = computed_style.getPropertyValue(
"--border-color-primary"
);
let font_family = computed_style.fontFamily;
let title_weight = computed_style.getPropertyValue(
"--block-title-text-weight"
) as
| "bold"
| "normal"
| 100
| 200
| 300
| 400
| 500
| 600
| 700
| 800
| 900;
const font_to_px_val = (font: string): number => {
return font.endsWith("px") ? parseFloat(font.slice(0, -2)) : 12;
};
let text_size_md = font_to_px_val(
computed_style.getPropertyValue("--text-md")
);
let text_size_sm = font_to_px_val(
computed_style.getPropertyValue("--text-sm")
);
return {
$schema: "https://vega.github.io/schema/vega-lite/v5.17.0.json",
background: "transparent",
config: {
autosize: { type: "fit", contains: "padding" },
axis: {
labelFont: font_family,
labelColor: body_text_color,
titleFont: font_family,
titleColor: body_text_color,
titlePadding: 8,
tickColor: borderColorPrimary,
labelFontSize: text_size_sm,
gridColor: borderColorPrimary,
titleFontWeight: "normal",
titleFontSize: text_size_sm,
labelFontWeight: "normal",
domain: false,
labelAngle: 0
},
legend: {
labelColor: body_text_color,
labelFont: font_family,
titleColor: body_text_color,
titleFont: font_family,
titleFontWeight: "normal",
titleFontSize: text_size_sm,
labelFontWeight: "normal",
offset: 2
},
title: {
color: body_text_color,
font: font_family,
fontSize: text_size_md,
fontWeight: title_weight,
anchor: "middle"
},
view: { stroke: borderColorPrimary },
mark: {
stroke: value.mark !== "bar" ? accent_color : undefined,
fill: value.mark === "bar" ? accent_color : undefined,
cursor: "crosshair"
}
},
data: { name: "data" },
datasets: {
data: _data
},
layer: ["plot", ...(value.mark === "line" ? ["hover"] : [])].map(
(mode) => {
return {
encoding: {
size:
value.mark === "line"
? mode == "plot"
? {
condition: {
empty: false,
param: "hoverPlot",
value: 3
},
value: 2
}
: {
condition: { empty: false, param: "hover", value: 100 },
value: 0
}
: undefined,
opacity:
mode === "plot"
? undefined
: {
condition: { empty: false, param: "hover", value: 1 },
value: 0
},
x: {
axis: {},
field: x,
title: x_title || x,
type: value.datatypes[x],
scale: _x_lim ? { domain: _x_lim } : undefined,
bin: _x_bin ? { step: _x_bin } : undefined,
sort: _sort
},
y: {
axis: {},
field: y,
title: y_title || y,
type: value.datatypes[y],
scale: y_lim ? { domain: y_lim } : undefined,
aggregate: aggregating ? _y_aggregate : undefined
},
color: color
? {
field: color,
legend: { orient: "bottom", title: color_title },
scale:
value.datatypes[color] === "nominal"
? {
domain: unique_colors,
range: color_map
? unique_colors.map((c) => color_map[c])
: undefined
}
: {
range: [
100, 200, 300, 400, 500, 600, 700, 800, 900
].map((n) =>
computed_style.getPropertyValue("--primary-" + n)
),
interpolate: "hsl"
},
type: value.datatypes[color]
}
: undefined,
tooltip: [
{
field: y,
type: value.datatypes[y],
aggregate: aggregating ? _y_aggregate : undefined,
title: y_title || y
},
{
field: x,
type: value.datatypes[x],
title: x_title || x,
format: x_temporal ? "%Y-%m-%d %H:%M:%S" : undefined,
bin: _x_bin ? { step: _x_bin } : undefined
},
...(color
? [
{
field: color,
type: value.datatypes[color]
}
]
: [])
]
},
strokeDash: {},
mark: { clip: true, type: mode === "hover" ? "point" : value.mark },
name: mode
};
}
),
// @ts-ignore
params: [
...(value.mark === "line"
? [
{
name: "hoverPlot",
select: {
clear: "mouseout",
fields: color ? [color] : [],
nearest: true,
on: "mouseover",
type: "point" as "point"
},
views: ["hover"]
},
{
name: "hover",
select: {
clear: "mouseout",
nearest: true,
on: "mouseover",
type: "point" as "point"
},
views: ["hover"]
}
]
: []),
...(_selectable
? [
{
name: "brush",
select: {
encodings: ["x"],
mark: { fill: "gray", fillOpacity: 0.3, stroke: "none" },
type: "interval" as "interval"
},
views: ["plot"]
}
]
: [])
],
width: chart_element.offsetWidth,
title: title || undefined
};
}
export let label = "Textbox";
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let show_label: boolean;
export let scale: number | null = null;
export let min_width: number | undefined = undefined;
export let loading_status: LoadingStatus | undefined = undefined;
export let height: number | undefined = undefined;
</script>
<Block
{visible}
{elem_id}
{elem_classes}
{scale}
{min_width}
allow_overflow={false}
padding={true}
{height}
>
{#if loading_status}
<StatusTracker
autoscroll={gradio.autoscroll}
i18n={gradio.i18n}
{...loading_status}
on:clear_status={() => gradio.dispatch("clear_status", loading_status)}
/>
{/if}
<BlockTitle {show_label} info={undefined}>{label}</BlockTitle>
<div bind:this={chart_element}></div>
{#if value}
{#if caption}
<p class="caption">{caption}</p>
{/if}
{:else}
<Empty unpadded_box={true}><LabelIcon /></Empty>
{/if}
</Block>
<style>
div {
width: 100%;
}
:global(#vg-tooltip-element) {
font-family: var(--font) !important;
font-size: var(--text-xs) !important;
box-shadow: none !important;
background-color: var(--block-background-fill) !important;
border: 1px solid var(--border-color-primary) !important;
color: var(--body-text-color) !important;
}
:global(#vg-tooltip-element .key) {
color: var(--body-text-color-subdued) !important;
}
.caption {
padding: 0 4px;
margin: 0;
text-align: center;
}
</style>

View File

@ -0,0 +1,28 @@
{
"name": "@gradio/nativeplot",
"version": "0.0.1",
"description": "Gradio UI packages",
"type": "module",
"author": "",
"license": "ISC",
"private": false,
"main_changeset": true,
"exports": {
".": "./Index.svelte",
"./example": "./Example.svelte",
"./package.json": "./package.json"
},
"dependencies": {
"@gradio/atoms": "workspace:^",
"@gradio/icons": "workspace:^",
"@gradio/statustracker": "workspace:^",
"@gradio/utils": "workspace:^",
"@gradio/theme": "workspace:^",
"vega": "^5.23.0",
"vega-embed": "^6.25.0",
"vega-lite": "^5.12.0"
},
"devDependencies": {
"@gradio/preview": "workspace:^"
}
}

View File

@ -122,6 +122,7 @@
"@gradio/markdown": "workspace:^",
"@gradio/model3d": "workspace:^",
"@gradio/multimodaltextbox": "workspace:^",
"@gradio/nativeplot": "workspace:^",
"@gradio/number": "workspace:^",
"@gradio/paramviewer": "workspace:^",
"@gradio/plot": "workspace:^",

47
pnpm-lock.yaml generated
View File

@ -270,6 +270,9 @@ importers:
'@gradio/multimodaltextbox':
specifier: workspace:^
version: link:js/multimodaltextbox
'@gradio/nativeplot':
specifier: workspace:^
version: link:js/nativeplot
'@gradio/number':
specifier: workspace:^
version: link:js/number
@ -625,6 +628,9 @@ importers:
'@gradio/multimodaltextbox':
specifier: workspace:^
version: link:../multimodaltextbox
'@gradio/nativeplot':
specifier: workspace:^
version: link:../nativeplot
'@gradio/number':
specifier: workspace:^
version: link:../number
@ -1549,6 +1555,37 @@ importers:
specifier: workspace:^
version: link:../preview
js/nativeplot:
dependencies:
'@gradio/atoms':
specifier: workspace:^
version: link:../atoms
'@gradio/icons':
specifier: workspace:^
version: link:../icons
'@gradio/statustracker':
specifier: workspace:^
version: link:../statustracker
'@gradio/theme':
specifier: workspace:^
version: link:../theme
'@gradio/utils':
specifier: workspace:^
version: link:../utils
vega:
specifier: ^5.23.0
version: 5.23.0
vega-embed:
specifier: ^6.25.0
version: 6.25.0(vega-lite@5.12.0(vega@5.23.0))(vega@5.23.0)
vega-lite:
specifier: ^5.12.0
version: 5.12.0(vega@5.23.0)
devDependencies:
'@gradio/preview':
specifier: workspace:^
version: link:../preview
js/number:
dependencies:
'@gradio/atoms':
@ -6025,10 +6062,6 @@ packages:
engines: {node: '>=12'}
hasBin: true
escalade@3.1.1:
resolution: {integrity: sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==}
engines: {node: '>=6'}
escalade@3.1.2:
resolution: {integrity: sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==}
engines: {node: '>=6'}
@ -14422,8 +14455,6 @@ snapshots:
'@esbuild/win32-ia32': 0.21.0
'@esbuild/win32-x64': 0.21.0
escalade@3.1.1: {}
escalade@3.1.2: {}
escape-html@1.0.3: {}
@ -18217,7 +18248,7 @@ snapshots:
d3-timer: 3.0.1
vega-dataflow: 5.7.5
vega-format: 1.1.1
vega-functions: 5.13.2
vega-functions: 5.14.0
vega-runtime: 6.1.4
vega-scenegraph: 4.10.2
vega-util: 1.17.2
@ -18565,7 +18596,7 @@ snapshots:
yargs@17.7.2:
dependencies:
cliui: 8.0.1
escalade: 3.1.1
escalade: 3.1.2
get-caller-file: 2.0.5
require-directory: 2.1.1
string-width: 4.2.3

View File

@ -1,5 +1,4 @@
aiofiles>=22.0,<24.0
altair>=5.0,<6.0
anyio>=3.0,<5.0
fastapi
ffmpy

View File

@ -1,12 +1,14 @@
from datetime import datetime, timedelta
import pandas as pd
import vega_datasets
cars = vega_datasets.data.cars()
stocks = vega_datasets.data.stocks()
barley = vega_datasets.data.barley()
simple = pd.DataFrame(
{
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
"b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
"c": [0] * 9,
"date": [datetime.now() - timedelta(days=i) for i in range(9)],
}
)

View File

@ -1,118 +0,0 @@
import json
from unittest.mock import MagicMock, patch
import gradio as gr
from .plot_data import barley, simple
class TestBarPlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
assert gr.BarPlot().get_config() == {
"caption": None,
"elem_id": None,
"elem_classes": [],
"interactive": None,
"label": None,
"name": "barplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,
"show_label": False,
"container": True,
"min_width": 160,
"scale": None,
"value": None,
"visible": True,
"x": None,
"y": None,
"color": None,
"vertical": True,
"group": None,
"title": None,
"tooltip": None,
"x_title": None,
"y_title": None,
"color_legend_title": None,
"group_title": None,
"color_legend_position": None,
"height": None,
"width": None,
"y_lim": None,
"x_label_angle": None,
"y_label_angle": None,
"sort": None,
"_selectable": False,
"key": None,
}
def test_no_color(self):
plot = gr.BarPlot(
x="a",
y="b",
tooltip=["a", "b"],
title="Made Up Bar Plot",
x_title="Variable A",
sort="x",
)
assert (output := plot.postprocess(simple))
output = output.model_dump()
assert sorted(output.keys()) == ["chart", "plot", "type"]
assert output["chart"] == "bar"
config = json.loads(output["plot"])
assert config["encoding"]["x"]["sort"] == "x"
assert config["encoding"]["x"]["field"] == "a"
assert config["encoding"]["x"]["title"] == "Variable A"
assert config["encoding"]["y"]["field"] == "b"
assert config["encoding"]["y"]["title"] == "b"
assert config["title"] == "Made Up Bar Plot"
assert "height" not in config
assert "width" not in config
def test_height_width(self):
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
assert (output := plot.postprocess(simple))
output = output.model_dump()
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
def test_ylim(self):
plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100])
assert (output := plot.postprocess(simple))
output = output.model_dump()
config = json.loads(output["plot"])
assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}
def test_horizontal(self):
output = gr.BarPlot(
simple,
x="a",
y="b",
x_title="Variable A",
y_title="Variable B",
title="Simple Bar Plot with made up data",
tooltip=["a", "b"],
vertical=False,
y_lim=[20, 100],
).get_config()
assert output["value"]["chart"] == "bar"
config = json.loads(output["value"]["plot"])
assert config["encoding"]["x"]["field"] == "b"
assert config["encoding"]["x"]["scale"] == {"domain": [20, 100]}
assert config["encoding"]["x"]["title"] == "Variable B"
assert config["encoding"]["y"]["field"] == "a"
assert config["encoding"]["y"]["title"] == "Variable A"
def test_barplot_accepts_fn_as_value(self):
plot = gr.BarPlot(
value=lambda: barley.sample(frac=0.1, replace=False),
x="year",
y="yield",
)
assert isinstance(plot.value, dict)
assert isinstance(plot.value["plot"], str)

View File

@ -1,112 +0,0 @@
import json
from unittest.mock import MagicMock, patch
import gradio as gr
from .plot_data import stocks
class TestLinePlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
assert gr.LinePlot().get_config() == {
"caption": None,
"elem_id": None,
"elem_classes": [],
"interactive": None,
"label": None,
"name": "lineplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,
"show_label": False,
"container": True,
"min_width": 160,
"scale": None,
"value": None,
"visible": True,
"x": None,
"y": None,
"color": None,
"stroke_dash": None,
"overlay_point": None,
"title": None,
"tooltip": [],
"x_title": None,
"y_title": None,
"color_legend_title": None,
"stroke_dash_legend_title": None,
"color_legend_position": None,
"stroke_dash_legend_position": None,
"height": None,
"width": None,
"x_lim": None,
"y_lim": None,
"x_label_angle": None,
"y_label_angle": None,
"_selectable": False,
"key": None,
}
def test_no_color(self):
plot = gr.LinePlot(
x="date",
y="price",
tooltip=["symbol", "price"],
title="Stock Performance",
x_title="Trading Day",
)
output = plot.postprocess(stocks).model_dump() # type: ignore
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
for layer in config["layer"]:
assert layer["mark"]["type"] in ["line", "point"]
assert layer["encoding"]["x"]["field"] == "date"
assert layer["encoding"]["x"]["title"] == "Trading Day"
assert layer["encoding"]["y"]["field"] == "price"
assert config["title"] == "Stock Performance"
assert "height" not in config
assert "width" not in config
def test_height_width(self):
plot = gr.LinePlot(x="date", y="price", height=100, width=200)
output = plot.postprocess(stocks).model_dump() # type: ignore
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
def test_xlim_ylim(self):
plot = gr.LinePlot(x="date", y="price", x_lim=[200, 400], y_lim=[300, 500])
output = plot.postprocess(stocks).model_dump() # type: ignore
config = json.loads(output["plot"])
for layer in config["layer"]:
assert layer["encoding"]["x"]["scale"] == {"domain": [200, 400]}
assert layer["encoding"]["y"]["scale"] == {"domain": [300, 500]}
def test_color_encoding(self):
plot = gr.LinePlot(
x="date", y="price", tooltip="symbol", color="symbol", overlay_point=True
)
output = plot.postprocess(stocks).model_dump() # type: ignore
config = json.loads(output["plot"])
for layer in config["layer"]:
assert layer["encoding"]["color"]["field"] == "symbol"
assert layer["encoding"]["color"]["scale"] == {
"domain": ["MSFT", "AMZN", "IBM", "GOOG", "AAPL"],
"range": [0, 1, 2, 3, 4],
}
assert layer["encoding"]["color"]["type"] == "nominal"
if layer["mark"]["type"] == "point":
assert layer["encoding"]["opacity"] == {}
def test_lineplot_accepts_fn_as_value(self):
plot = gr.LinePlot(
value=lambda: stocks.sample(frac=0.1, replace=False),
x="date",
y="price",
color="symbol",
)
assert isinstance(plot.value, dict)
assert isinstance(plot.value["plot"], str)

View File

@ -0,0 +1,32 @@
import gradio as gr
from .plot_data import barley, simple
class TestNativePlot:
def test_plot_recognizes_correct_datatypes(self):
plot = gr.BarPlot(
value=simple,
x="date",
y="b",
)
assert plot.value["datatypes"]["date"] == "temporal"
assert plot.value["datatypes"]["b"] == "quantitative"
plot = gr.BarPlot(
value=simple,
x="a",
y="b",
color="c",
)
assert plot.value["datatypes"]["a"] == "nominal"
assert plot.value["datatypes"]["b"] == "quantitative"
assert plot.value["datatypes"]["c"] == "quantitative"
def test_plot_accepts_fn_as_value(self):
plot = gr.BarPlot(
value=lambda: barley.sample(frac=0.1, replace=False),
x="year",
y="yield",
)
assert plot.value["mark"] == "bar"

View File

@ -1,168 +0,0 @@
import json
from unittest.mock import MagicMock, patch
import gradio as gr
from .plot_data import cars
class TestScatterPlot:
@patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
def test_get_config(self):
print(gr.ScatterPlot().get_config())
assert gr.ScatterPlot().get_config() == {
"caption": None,
"elem_id": None,
"elem_classes": [],
"interactive": None,
"label": None,
"name": "scatterplot",
"bokeh_version": "3.0.3",
"show_actions_button": False,
"proxy_url": None,
"show_label": False,
"container": True,
"min_width": 160,
"scale": None,
"value": None,
"visible": True,
"x": None,
"y": None,
"color": None,
"size": None,
"shape": None,
"title": None,
"tooltip": None,
"x_title": None,
"y_title": None,
"color_legend_title": None,
"size_legend_title": None,
"shape_legend_title": None,
"color_legend_position": None,
"size_legend_position": None,
"shape_legend_position": None,
"height": None,
"width": None,
"x_lim": None,
"y_lim": None,
"x_label_angle": None,
"y_label_angle": None,
"_selectable": False,
"key": None,
}
def test_no_color(self):
plot = gr.ScatterPlot(
x="Horsepower",
y="Miles_per_Gallon",
tooltip="Name",
title="Car Data",
x_title="Horse",
)
output = plot.postprocess(cars).model_dump() # type: ignore
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["encoding"]["x"]["field"] == "Horsepower"
assert config["encoding"]["x"]["title"] == "Horse"
assert config["encoding"]["y"]["field"] == "Miles_per_Gallon"
assert config["title"] == "Car Data"
assert "height" not in config
assert "width" not in config
def test_no_interactive(self):
plot = gr.ScatterPlot(
x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
)
output = plot.postprocess(cars).model_dump() # type: ignore
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert "selection" not in config
def test_height_width(self):
plot = gr.ScatterPlot(
x="Horsepower", y="Miles_per_Gallon", height=100, width=200
)
output = plot.postprocess(cars).model_dump() # type: ignore
assert sorted(output.keys()) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
def test_xlim_ylim(self):
plot = gr.ScatterPlot(
x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500]
)
output = plot.postprocess(cars).model_dump() # type: ignore
config = json.loads(output["plot"])
assert config["encoding"]["x"]["scale"] == {"domain": [200, 400]}
assert config["encoding"]["y"]["scale"] == {"domain": [300, 500]}
def test_color_encoding(self):
plot = gr.ScatterPlot(
x="Horsepower",
y="Miles_per_Gallon",
tooltip="Name",
title="Car Data",
color="Origin",
)
output = plot.postprocess(cars).model_dump() # type: ignore
config = json.loads(output["plot"])
assert config["encoding"]["color"]["field"] == "Origin"
assert config["encoding"]["color"]["scale"] == {
"domain": ["USA", "Europe", "Japan"],
"range": [0, 1, 2],
}
assert config["encoding"]["color"]["type"] == "nominal"
def test_two_encodings(self):
plot = gr.ScatterPlot(
show_label=False,
title="Two encodings",
x="Horsepower",
y="Miles_per_Gallon",
color="Acceleration",
shape="Origin",
)
output = plot.postprocess(cars).model_dump() # type: ignore
config = json.loads(output["plot"])
assert config["encoding"]["color"]["field"] == "Acceleration"
assert config["encoding"]["color"]["scale"] == {
"domain": [cars.Acceleration.min(), cars.Acceleration.max()],
"range": [0, 1],
}
assert config["encoding"]["color"]["type"] == "quantitative"
assert config["encoding"]["shape"]["field"] == "Origin"
assert config["encoding"]["shape"]["type"] == "nominal"
def test_legend_position(self):
plot = gr.ScatterPlot(
show_label=False,
title="Two encodings",
x="Horsepower",
y="Miles_per_Gallon",
color="Acceleration",
color_legend_position="none",
color_legend_title="Foo",
shape="Origin",
shape_legend_position="none",
shape_legend_title="Bar",
size="Acceleration",
size_legend_title="Accel",
size_legend_position="none",
)
output = plot.postprocess(cars).model_dump() # type: ignore
config = json.loads(output["plot"])
assert config["encoding"]["color"]["legend"] is None
assert config["encoding"]["shape"]["legend"] is None
assert config["encoding"]["size"]["legend"] is None
def test_scatterplot_accepts_fn_as_value(self):
plot = gr.ScatterPlot(
value=lambda: cars.sample(frac=0.1, replace=False),
x="Horsepower",
y="Miles_per_Gallon",
color="Origin",
)
assert isinstance(plot.value, dict)
assert isinstance(plot.value["plot"], str)