Bar Plot Component (#3157)

* Add code - first draft

* Getting better

* Work out the bugs

* Fix docstrings

* CHANGELOG

* Fix test

* Generate notebooks

* Add unit test

* Undo website package.lock

* Fix demo

* Fix notebooks

* Fix docstrings

* Improve example in CHANGELOG

* Address comments + feedback

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-02-09 16:42:25 -05:00 committed by GitHub
parent f92109621a
commit c06b4eab16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 628 additions and 24 deletions

View File

@ -1,7 +1,38 @@
# Upcoming Release
## New Features:
No changes to highlight.
### New `gr.BarPlot` component! 📊
Create interactive bar plots from a high-level interface with `gr.BarPlot`.
No need to remember matplotlib syntax anymore!
Example usage:
```python
import gradio as gr
import pandas as pd
simple = pd.DataFrame({
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
})
with gr.Blocks() as demo:
gr.BarPlot(
simple,
x="a",
y="b",
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
)
demo.launch()
```
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3157](https://github.com/gradio-app/gradio/pull/3157)
## Bug Fixes:
No changes to highlight.

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " fig_m, ax = plt.subplots()\n", " ax.bar(x=df[\"rideable_type\"], height=df[\"n\"])\n", " ax.set_title(\"Number of rides by bycycle type\")\n", " ax.set_ylabel(\"Number of Rides\")\n", " ax.set_xlabel(\"Bicycle Type\")\n", " return fig_m\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " fig_m, ax = plt.subplots()\n", " ax.bar(x=df[\"station\"], height=df[\"n\"])\n", " ax.set_title(\"Most popular stations\")\n", " ax.set_ylabel(\"Number of Rides\")\n", " ax.set_xlabel(\"Station Name\")\n", " ax.set_xticklabels(df[\"station\"], rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n", " ax.tick_params(axis=\"x\", labelsize=8)\n", " fig_m.tight_layout()\n", " return fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.Plot()\n", " station = gr.Plot()\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -27,12 +27,7 @@ def get_count_ride_type():
""",
con=connection_string,
)
fig_m, ax = plt.subplots()
ax.bar(x=df["rideable_type"], height=df["n"])
ax.set_title("Number of rides by bycycle type")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Bicycle Type")
return fig_m
return df
def get_most_popular_stations():
@ -48,15 +43,7 @@ def get_most_popular_stations():
""",
con=connection_string,
)
fig_m, ax = plt.subplots()
ax.bar(x=df["station"], height=df["n"])
ax.set_title("Most popular stations")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Station Name")
ax.set_xticklabels(df["station"], rotation=45, ha="right", rotation_mode="anchor")
ax.tick_params(axis="x", labelsize=8)
fig_m.tight_layout()
return fig_m
return df
with gr.Blocks() as demo:
@ -78,8 +65,28 @@ with gr.Blocks() as demo:
"""
)
with gr.Row():
bike_type = gr.Plot()
station = gr.Plot()
bike_type = gr.BarPlot(
x="rideable_type",
y='n',
title="Number of rides per bicycle type",
y_title="Number of Rides",
x_title="Bicycle Type",
vertical=False,
tooltip=['rideable_type', "n"],
height=300,
width=300,
)
station = gr.BarPlot(
x='station',
y='n',
title="Most Popular Stations",
y_title="Number of Rides",
x_title="Station Name",
vertical=False,
tooltip=['station', 'n'],
height=300,
width=300
)
demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
demo.load(get_most_popular_stations, inputs=None, outputs=station)

View File

@ -0,0 +1,89 @@
import gradio as gr
import pandas as pd
from vega_datasets import data
barley = data.barley()
simple = pd.DataFrame({
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
})
def bar_plot_fn(display):
if display == "simple":
return gr.BarPlot.update(
simple,
x="a",
y="b",
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
y_lim=[20, 100]
)
elif display == "stacked":
return gr.BarPlot.update(
barley,
x="variety",
y="yield",
color="site",
title="Barley Yield Data",
tooltip=['variety', 'site']
)
elif display == "grouped":
return gr.BarPlot.update(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="",
tooltip=["yield", "site", "year"]
)
elif display == "simple-horizontal":
return gr.BarPlot.update(
simple,
x="a",
y="b",
x_title="Variable A",
y_title="Variable B",
title="Simple Bar Plot with made up data",
tooltip=['a', 'b'],
vertical=False,
y_lim=[20, 100]
)
elif display == "stacked-horizontal":
return gr.BarPlot.update(
barley,
x="variety",
y="yield",
color="site",
title="Barley Yield Data",
vertical=False,
tooltip=['variety', 'site']
)
elif display == "grouped-horizontal":
return gr.BarPlot.update(
barley.astype({"year": str}),
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="",
tooltip=["yield", "site", "year"],
vertical=False
)
with gr.Blocks() as bar_plot:
with gr.Row():
with gr.Column():
display = gr.Dropdown(
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
value="simple",
label="Type of Bar Plot"
)
with gr.Column():
plot = gr.BarPlot(show_label=False).style(container=True)
display.change(bar_plot_fn, inputs=display, outputs=plot)
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -2,6 +2,7 @@ import gradio as gr
from scatter_plot_demo import scatter_plot
from line_plot_demo import line_plot
from bar_plot_demo import bar_plot
with gr.Blocks() as demo:
@ -10,6 +11,8 @@ with gr.Blocks() as demo:
scatter_plot.render()
with gr.TabItem("Line Plot"):
line_plot.render()
with gr.TabItem("Bar Plot"):
bar_plot.render()
if __name__ == "__main__":
demo.launch()

View File

@ -10,6 +10,7 @@ from gradio.components import (
HTML,
JSON,
Audio,
BarPlot,
Button,
Carousel,
Chatbot,

View File

@ -4371,11 +4371,11 @@ class ScatterPlot(Plot):
color_legend_position,
size_legend_position,
shape_legend_position,
interactive,
height,
width,
x_lim,
y_lim,
interactive,
]
if any(properties):
if not isinstance(value, pd.DataFrame):
@ -4591,8 +4591,8 @@ class LinePlot(Plot):
caption: The (optional) caption to display below the plot.
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
label: The (optional) label to display on the top left corner of the plot.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
show_label: Whether the label should be displayed.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
visible: Whether the plot should be visible.
elem_id: Unique id used for custom css targetting.
"""
@ -4863,6 +4863,320 @@ class LinePlot(Plot):
return {"type": "altair", "plot": chart.to_json(), "chart": "line"}
@document("change", "clear")
class BarPlot(Plot):
"""
Create a bar plot.
Preprocessing: this component does *not* accept input.
Postprocessing: expects a pandas dataframe with the data to plot.
Demos: native_plots, chicago-bikeshare-dashboard
"""
def __init__(
self,
value: pd.DataFrame | Callable | None = None,
x: str | None = None,
y: str | None = None,
*,
color: str | None = None,
vertical: bool = True,
group: str | None = None,
title: str | None = None,
tooltip: List[str] | str | None = None,
x_title: str | None = None,
y_title: str | None = None,
color_legend_title: str | None = None,
group_title: str | None = None,
color_legend_position: str | None = None,
height: int | None = None,
width: int | None = None,
y_lim: List[int] | None = None,
caption: str | None = None,
interactive: bool | None = True,
label: str | None = None,
show_label: bool = True,
every: float | None = None,
visible: bool = True,
elem_id: str | None = None,
):
"""
Parameters:
value: The pandas dataframe containing the data to display in a scatter plot.
x: Column corresponding to the x axis.
y: Column corresponding to the y axis.
color: The column to determine the bar color. Must be categorical (discrete values).
vertical: If True, the bars will be displayed vertically. If False, the x and y axis will be switched, displaying the bars horizontally. Default is True.
group: The column with which to split the overall plot into smaller subplots.
title: The title to display on top of the chart.
tooltip: The column (or list of columns) to display on the tooltip when a user hovers over a bar.
x_title: The title given to the x axis. By default, uses the value of the x parameter.
y_title: The title given to the y axis. By default, uses the value of the y parameter.
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
height: The height of the plot in pixels.
width: The width of the plot in pixels.
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
label: The (optional) label to display on the top left corner of the plot.
show_label: Whether the label should be displayed.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
visible: Whether the plot should be visible.
elem_id: Unique id used for custom css targetting.
"""
self.x = x
self.y = y
self.color = color
self.vertical = vertical
self.group = group
self.group_title = group_title
self.tooltip = tooltip
self.title = title
self.x_title = x_title
self.y_title = y_title
self.color_legend_title = color_legend_title
self.group_title = group_title
self.color_legend_position = color_legend_position
self.y_lim = y_lim
self.caption = caption
self.interactive_chart = interactive
self.width = width
self.height = height
super().__init__(
value=value,
label=label,
show_label=show_label,
visible=visible,
elem_id=elem_id,
every=every,
)
def get_config(self):
config = super().get_config()
config["caption"] = self.caption
return config
def get_block_name(self) -> str:
return "plot"
@staticmethod
def update(
value: pd.DataFrame | Dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
x: str | None = None,
y: str | None = None,
color: str | None = None,
vertical: bool = True,
group: str | None = None,
title: str | None = None,
tooltip: List[str] | str | None = None,
x_title: str | None = None,
y_title: str | None = None,
color_legend_title: str | None = None,
group_title: str | None = None,
color_legend_position: str | None = None,
height: int | None = None,
width: int | None = None,
y_lim: List[int] | None = None,
caption: str | None = None,
interactive: bool | None = True,
label: str | None = None,
show_label: bool = True,
visible: bool = True,
):
"""Update an existing BarPlot component.
If updating any of the plot properties (color, size, etc) the value, x, and y parameters must be specified.
Parameters:
value: The pandas dataframe containing the data to display in a scatter plot.
x: Column corresponding to the x axis.
y: Column corresponding to the y axis.
color: The column to determine the bar color. Must be categorical (discrete values).
vertical: If True, the bars will be displayed vertically. If False, the x and y axis will be switched, displaying the bars horizontally. Default is True.
group: The column with which to split the overall plot into smaller subplots.
title: The title to display on top of the chart.
tooltip: The column (or list of columns) to display on the tooltip when a user hovers over a bar.
x_title: The title given to the x axis. By default, uses the value of the x parameter.
y_title: The title given to the y axis. By default, uses the value of the y parameter.
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
height: The height of the plot in pixels.
width: The width of the plot in pixels.
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
label: The (optional) label to display on the top left corner of the plot.
show_label: Whether the label should be displayed.
visible: Whether the plot should be visible.
"""
properties = [
x,
y,
color,
vertical,
group,
title,
tooltip,
x_title,
y_title,
color_legend_title,
group_title,
color_legend_position,
height,
width,
y_lim,
interactive,
]
if any(properties):
if not isinstance(value, pd.DataFrame):
raise ValueError(
"In order to update plot properties the value parameter "
"must be provided, and it must be a Dataframe. Please pass a value "
"parameter to gr.BarPlot.update."
)
if x is None or y is None:
raise ValueError(
"In order to update plot properties, the x and y axis data "
"must be specified. Please pass valid values for x an y to "
"gr.BarPlot.update."
)
chart = BarPlot.create_plot(value, *properties)
value = {"type": "altair", "plot": chart.to_json(), "chart": "bar"}
updated_config = {
"label": label,
"show_label": show_label,
"visible": visible,
"value": value,
"caption": caption,
"__type__": "update",
}
return updated_config
@staticmethod
def create_plot(
value: pd.DataFrame,
x: str,
y: str,
color: str | None = None,
vertical: bool = True,
group: str | None = None,
title: str | None = None,
tooltip: List[str] | str | None = None,
x_title: str | None = None,
y_title: str | None = None,
color_legend_title: str | None = None,
group_title: str | None = None,
color_legend_position: str | None = None,
height: int | None = None,
width: int | None = None,
y_lim: List[int] | None = None,
interactive: bool | None = True,
):
"""Helper for creating the scatter plot."""
interactive = True if interactive is None else interactive
orientation = (
dict(field=group, title=group_title if group_title is not None else group)
if group
else {}
)
x_title = x_title or x
y_title = y_title or y
# If horizontal, switch x and y
if not vertical:
y, x = x, y
x = f"sum({x}):Q"
y_title, x_title = x_title, y_title
orientation = {"row": alt.Row(**orientation)} if orientation else {} # type: ignore
x_lim = y_lim
y_lim = None
else:
y = f"sum({y}):Q"
x_lim = None
orientation = {"column": alt.Column(**orientation)} if orientation else {} # type: ignore
encodings = dict(
x=alt.X(
x, # type: ignore
title=x_title, # type: ignore
scale=AltairPlot.create_scale(x_lim), # type: ignore
),
y=alt.Y(
y, # type: ignore
title=y_title, # type: ignore
scale=AltairPlot.create_scale(y_lim), # type: ignore
),
**orientation,
)
properties = {}
if title:
properties["title"] = title
if height:
properties["height"] = height
if width:
properties["width"] = width
if color:
domain = value[color].unique().tolist()
range_ = list(range(len(domain)))
encodings["color"] = {
"field": color,
"type": "nominal",
"scale": {"domain": domain, "range": range_},
"legend": AltairPlot.create_legend(
position=color_legend_position, title=color_legend_title or color
),
}
if tooltip:
encodings["tooltip"] = tooltip
chart = (
alt.Chart(value) # type: ignore
.mark_bar() # type: ignore
.encode(**encodings)
.properties(background="transparent", **properties)
)
if interactive:
chart = chart.interactive()
return chart
def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None:
# if None or update
if y is None or isinstance(y, Dict):
return y
if self.x is None or self.y is None:
raise ValueError("No value provided for required parameters `x` and `y`.")
chart = self.create_plot(
value=y,
x=self.x,
y=self.y,
color=self.color,
vertical=self.vertical,
group=self.group,
title=self.title,
tooltip=self.tooltip,
x_title=self.x_title,
y_title=self.y_title,
color_legend_title=self.color_legend_title,
color_legend_position=self.color_legend_position,
group_title=self.group_title,
y_lim=self.y_lim,
interactive=self.interactive_chart,
height=self.height,
width=self.width,
)
return {"type": "altair", "plot": chart.to_json(), "chart": "bar"}
@document("change")
class Markdown(IOComponent, Changeable, SimpleSerializable):
"""

View File

@ -418,7 +418,7 @@ class TestComponentsInBlocks:
io_components = [
c()
for c in io_components
if c not in [gr.State, gr.Button, gr.ScatterPlot, gr.LinePlot]
if c not in [gr.State, gr.Button, gr.ScatterPlot, gr.LinePlot, gr.BarPlot]
]
with gr.Blocks() as demo:
for component in io_components:

View File

@ -2005,6 +2005,13 @@ def test_dataset_calls_as_example(*mocks):
cars = vega_datasets.data.cars()
stocks = vega_datasets.data.stocks()
barley = vega_datasets.data.barley()
simple = pd.DataFrame(
{
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
"b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
}
)
class TestScatterPlot:
@ -2348,3 +2355,146 @@ class TestLinePlot:
)
assert isinstance(plot.value, dict)
assert isinstance(plot.value["plot"], str)
class TestBarPlot:
def test_get_config(self):
assert gr.BarPlot().get_config() == {
"caption": None,
"elem_id": None,
"interactive": None,
"label": None,
"name": "plot",
"root_url": None,
"show_label": True,
"style": {},
"value": None,
"visible": True,
}
def test_no_color(self):
plot = gr.BarPlot(
x="a",
y="b",
tooltip=["a", "b"],
title="Made Up Bar Plot",
x_title="Variable A",
)
output = plot.postprocess(simple)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
assert output["chart"] == "bar"
config = json.loads(output["plot"])
assert config["encoding"]["x"]["field"] == "a"
assert config["encoding"]["x"]["title"] == "Variable A"
assert config["encoding"]["y"]["field"] == "b"
assert config["encoding"]["y"]["title"] == "b"
assert config["title"] == "Made Up Bar Plot"
assert "height" not in config
assert "width" not in config
def test_height_width(self):
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
output = plot.postprocess(simple)
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
config = json.loads(output["plot"])
assert config["height"] == 100
assert config["width"] == 200
output = gr.BarPlot.update(simple, x="a", y="b", height=100, width=200)
config = json.loads(output["value"]["plot"])
assert config["height"] == 100
assert config["width"] == 200
def test_ylim(self):
plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100])
output = plot.postprocess(simple)
config = json.loads(output["plot"])
assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}
def test_horizontal(self):
output = gr.BarPlot.update(
simple,
x="a",
y="b",
x_title="Variable A",
y_title="Variable B",
title="Simple Bar Plot with made up data",
tooltip=["a", "b"],
vertical=False,
y_lim=[20, 100],
)
assert output["value"]["chart"] == "bar"
config = json.loads(output["value"]["plot"])
assert config["encoding"]["x"]["field"] == "b"
assert config["encoding"]["x"]["scale"] == {"domain": [20, 100]}
assert config["encoding"]["x"]["title"] == "Variable B"
assert config["encoding"]["y"]["field"] == "a"
assert config["encoding"]["y"]["title"] == "Variable A"
def test_stack_via_color(self):
output = gr.BarPlot.update(
barley,
x="variety",
y="yield",
color="site",
title="Barley Yield Data",
color_legend_title="Site",
color_legend_position="bottom",
)
config = json.loads(output["value"]["plot"])
assert config["encoding"]["color"]["field"] == "site"
assert config["encoding"]["color"]["legend"] == {
"title": "Site",
"orient": "bottom",
}
assert config["encoding"]["color"]["scale"] == {
"domain": [
"University Farm",
"Waseca",
"Morris",
"Crookston",
"Grand Rapids",
"Duluth",
],
"range": [0, 1, 2, 3, 4, 5],
}
def test_group(self):
output = gr.BarPlot.update(
barley,
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="",
tooltip=["yield", "site", "year"],
)
config = json.loads(output["value"]["plot"])
assert config["encoding"]["column"] == {"field": "site", "title": ""}
def test_group_horizontal(self):
output = gr.BarPlot.update(
barley,
x="year",
y="yield",
color="year",
group="site",
title="Barley Yield by Year and Site",
group_title="Site Title",
tooltip=["yield", "site", "year"],
vertical=False,
)
config = json.loads(output["value"]["plot"])
assert config["encoding"]["row"] == {"field": "site", "title": "Site Title"}
def test_barplot_accepts_fn_as_value(self):
plot = gr.BarPlot(
value=lambda: barley.sample(frac=0.1, replace=False),
x="year",
y="yield",
)
assert isinstance(plot.value, dict)
assert isinstance(plot.value["plot"], str)

View File

@ -58,6 +58,14 @@
);
}
});
break;
case "bar":
if (spec.encoding.color) {
spec.encoding.color.scale.range = spec.encoding.color.scale.range.map(
(e, i) => get_color(i)
);
}
break;
default:
break;
}

View File

@ -26,7 +26,8 @@ export function create_config(darkmode: boolean): VegaConfig {
title: {
color: darkmode ? dark : light,
font: "sans-serif",
fontWeight: "normal"
fontWeight: "normal",
anchor: "middle"
}
};
}