mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Bar Plot Component (#3157)
* Add code - first draft * Getting better * Work out the bugs * Fix docstrings * CHANGELOG * Fix test * Generate notebooks * Add unit test * Undo website package.lock * Fix demo * Fix notebooks * Fix docstrings * Improve example in CHANGELOG * Address comments + feedback --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
f92109621a
commit
c06b4eab16
33
CHANGELOG.md
33
CHANGELOG.md
@ -1,7 +1,38 @@
|
||||
# Upcoming Release
|
||||
|
||||
## New Features:
|
||||
No changes to highlight.
|
||||
|
||||
### New `gr.BarPlot` component! 📊
|
||||
|
||||
Create interactive bar plots from a high-level interface with `gr.BarPlot`.
|
||||
No need to remember matplotlib syntax anymore!
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
|
||||
simple = pd.DataFrame({
|
||||
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
|
||||
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
|
||||
})
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.BarPlot(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=['a', 'b'],
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
|
||||
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3157](https://github.com/gradio-app/gradio/pull/3157)
|
||||
|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " fig_m, ax = plt.subplots()\n", " ax.bar(x=df[\"rideable_type\"], height=df[\"n\"])\n", " ax.set_title(\"Number of rides by bycycle type\")\n", " ax.set_ylabel(\"Number of Rides\")\n", " ax.set_xlabel(\"Bicycle Type\")\n", " return fig_m\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " fig_m, ax = plt.subplots()\n", " ax.bar(x=df[\"station\"], height=df[\"n\"])\n", " ax.set_title(\"Most popular stations\")\n", " ax.set_ylabel(\"Number of Rides\")\n", " ax.set_xlabel(\"Station Name\")\n", " ax.set_xticklabels(df[\"station\"], rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n", " ax.tick_params(axis=\"x\", labelsize=8)\n", " fig_m.tight_layout()\n", " return fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.Plot()\n", " station = gr.Plot()\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -27,12 +27,7 @@ def get_count_ride_type():
|
||||
""",
|
||||
con=connection_string,
|
||||
)
|
||||
fig_m, ax = plt.subplots()
|
||||
ax.bar(x=df["rideable_type"], height=df["n"])
|
||||
ax.set_title("Number of rides by bycycle type")
|
||||
ax.set_ylabel("Number of Rides")
|
||||
ax.set_xlabel("Bicycle Type")
|
||||
return fig_m
|
||||
return df
|
||||
|
||||
|
||||
def get_most_popular_stations():
|
||||
@ -48,15 +43,7 @@ def get_most_popular_stations():
|
||||
""",
|
||||
con=connection_string,
|
||||
)
|
||||
fig_m, ax = plt.subplots()
|
||||
ax.bar(x=df["station"], height=df["n"])
|
||||
ax.set_title("Most popular stations")
|
||||
ax.set_ylabel("Number of Rides")
|
||||
ax.set_xlabel("Station Name")
|
||||
ax.set_xticklabels(df["station"], rotation=45, ha="right", rotation_mode="anchor")
|
||||
ax.tick_params(axis="x", labelsize=8)
|
||||
fig_m.tight_layout()
|
||||
return fig_m
|
||||
return df
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
@ -78,8 +65,28 @@ with gr.Blocks() as demo:
|
||||
"""
|
||||
)
|
||||
with gr.Row():
|
||||
bike_type = gr.Plot()
|
||||
station = gr.Plot()
|
||||
bike_type = gr.BarPlot(
|
||||
x="rideable_type",
|
||||
y='n',
|
||||
title="Number of rides per bicycle type",
|
||||
y_title="Number of Rides",
|
||||
x_title="Bicycle Type",
|
||||
vertical=False,
|
||||
tooltip=['rideable_type', "n"],
|
||||
height=300,
|
||||
width=300,
|
||||
)
|
||||
station = gr.BarPlot(
|
||||
x='station',
|
||||
y='n',
|
||||
title="Most Popular Stations",
|
||||
y_title="Number of Rides",
|
||||
x_title="Station Name",
|
||||
vertical=False,
|
||||
tooltip=['station', 'n'],
|
||||
height=300,
|
||||
width=300
|
||||
)
|
||||
|
||||
demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
|
||||
demo.load(get_most_popular_stations, inputs=None, outputs=station)
|
||||
|
89
demo/native_plots/bar_plot_demo.py
Normal file
89
demo/native_plots/bar_plot_demo.py
Normal file
@ -0,0 +1,89 @@
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
|
||||
from vega_datasets import data
|
||||
|
||||
barley = data.barley()
|
||||
simple = pd.DataFrame({
|
||||
'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'],
|
||||
'b': [28, 55, 43, 91, 81, 53, 19, 87, 52]
|
||||
})
|
||||
|
||||
def bar_plot_fn(display):
|
||||
if display == "simple":
|
||||
return gr.BarPlot.update(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=['a', 'b'],
|
||||
y_lim=[20, 100]
|
||||
)
|
||||
elif display == "stacked":
|
||||
return gr.BarPlot.update(
|
||||
barley,
|
||||
x="variety",
|
||||
y="yield",
|
||||
color="site",
|
||||
title="Barley Yield Data",
|
||||
tooltip=['variety', 'site']
|
||||
)
|
||||
elif display == "grouped":
|
||||
return gr.BarPlot.update(
|
||||
barley.astype({"year": str}),
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
group_title="",
|
||||
tooltip=["yield", "site", "year"]
|
||||
)
|
||||
elif display == "simple-horizontal":
|
||||
return gr.BarPlot.update(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
x_title="Variable A",
|
||||
y_title="Variable B",
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=['a', 'b'],
|
||||
vertical=False,
|
||||
y_lim=[20, 100]
|
||||
)
|
||||
elif display == "stacked-horizontal":
|
||||
return gr.BarPlot.update(
|
||||
barley,
|
||||
x="variety",
|
||||
y="yield",
|
||||
color="site",
|
||||
title="Barley Yield Data",
|
||||
vertical=False,
|
||||
tooltip=['variety', 'site']
|
||||
)
|
||||
elif display == "grouped-horizontal":
|
||||
return gr.BarPlot.update(
|
||||
barley.astype({"year": str}),
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
group_title="",
|
||||
tooltip=["yield", "site", "year"],
|
||||
vertical=False
|
||||
)
|
||||
|
||||
|
||||
with gr.Blocks() as bar_plot:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
display = gr.Dropdown(
|
||||
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
|
||||
value="simple",
|
||||
label="Type of Bar Plot"
|
||||
)
|
||||
with gr.Column():
|
||||
plot = gr.BarPlot(show_label=False).style(container=True)
|
||||
display.change(bar_plot_fn, inputs=display, outputs=plot)
|
||||
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -2,6 +2,7 @@ import gradio as gr
|
||||
|
||||
from scatter_plot_demo import scatter_plot
|
||||
from line_plot_demo import line_plot
|
||||
from bar_plot_demo import bar_plot
|
||||
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
@ -10,6 +11,8 @@ with gr.Blocks() as demo:
|
||||
scatter_plot.render()
|
||||
with gr.TabItem("Line Plot"):
|
||||
line_plot.render()
|
||||
with gr.TabItem("Bar Plot"):
|
||||
bar_plot.render()
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -10,6 +10,7 @@ from gradio.components import (
|
||||
HTML,
|
||||
JSON,
|
||||
Audio,
|
||||
BarPlot,
|
||||
Button,
|
||||
Carousel,
|
||||
Chatbot,
|
||||
|
@ -4371,11 +4371,11 @@ class ScatterPlot(Plot):
|
||||
color_legend_position,
|
||||
size_legend_position,
|
||||
shape_legend_position,
|
||||
interactive,
|
||||
height,
|
||||
width,
|
||||
x_lim,
|
||||
y_lim,
|
||||
interactive,
|
||||
]
|
||||
if any(properties):
|
||||
if not isinstance(value, pd.DataFrame):
|
||||
@ -4591,8 +4591,8 @@ class LinePlot(Plot):
|
||||
caption: The (optional) caption to display below the plot.
|
||||
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
|
||||
label: The (optional) label to display on the top left corner of the plot.
|
||||
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
|
||||
show_label: Whether the label should be displayed.
|
||||
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
|
||||
visible: Whether the plot should be visible.
|
||||
elem_id: Unique id used for custom css targetting.
|
||||
"""
|
||||
@ -4863,6 +4863,320 @@ class LinePlot(Plot):
|
||||
return {"type": "altair", "plot": chart.to_json(), "chart": "line"}
|
||||
|
||||
|
||||
@document("change", "clear")
|
||||
class BarPlot(Plot):
|
||||
"""
|
||||
Create a bar plot.
|
||||
|
||||
Preprocessing: this component does *not* accept input.
|
||||
Postprocessing: expects a pandas dataframe with the data to plot.
|
||||
|
||||
Demos: native_plots, chicago-bikeshare-dashboard
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
value: pd.DataFrame | Callable | None = None,
|
||||
x: str | None = None,
|
||||
y: str | None = None,
|
||||
*,
|
||||
color: str | None = None,
|
||||
vertical: bool = True,
|
||||
group: str | None = None,
|
||||
title: str | None = None,
|
||||
tooltip: List[str] | str | None = None,
|
||||
x_title: str | None = None,
|
||||
y_title: str | None = None,
|
||||
color_legend_title: str | None = None,
|
||||
group_title: str | None = None,
|
||||
color_legend_position: str | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
y_lim: List[int] | None = None,
|
||||
caption: str | None = None,
|
||||
interactive: bool | None = True,
|
||||
label: str | None = None,
|
||||
show_label: bool = True,
|
||||
every: float | None = None,
|
||||
visible: bool = True,
|
||||
elem_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
value: The pandas dataframe containing the data to display in a scatter plot.
|
||||
x: Column corresponding to the x axis.
|
||||
y: Column corresponding to the y axis.
|
||||
color: The column to determine the bar color. Must be categorical (discrete values).
|
||||
vertical: If True, the bars will be displayed vertically. If False, the x and y axis will be switched, displaying the bars horizontally. Default is True.
|
||||
group: The column with which to split the overall plot into smaller subplots.
|
||||
title: The title to display on top of the chart.
|
||||
tooltip: The column (or list of columns) to display on the tooltip when a user hovers over a bar.
|
||||
x_title: The title given to the x axis. By default, uses the value of the x parameter.
|
||||
y_title: The title given to the y axis. By default, uses the value of the y parameter.
|
||||
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
|
||||
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
|
||||
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
|
||||
height: The height of the plot in pixels.
|
||||
width: The width of the plot in pixels.
|
||||
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
|
||||
caption: The (optional) caption to display below the plot.
|
||||
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
|
||||
label: The (optional) label to display on the top left corner of the plot.
|
||||
show_label: Whether the label should be displayed.
|
||||
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
|
||||
visible: Whether the plot should be visible.
|
||||
elem_id: Unique id used for custom css targetting.
|
||||
"""
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.color = color
|
||||
self.vertical = vertical
|
||||
self.group = group
|
||||
self.group_title = group_title
|
||||
self.tooltip = tooltip
|
||||
self.title = title
|
||||
self.x_title = x_title
|
||||
self.y_title = y_title
|
||||
self.color_legend_title = color_legend_title
|
||||
self.group_title = group_title
|
||||
self.color_legend_position = color_legend_position
|
||||
self.y_lim = y_lim
|
||||
self.caption = caption
|
||||
self.interactive_chart = interactive
|
||||
self.width = width
|
||||
self.height = height
|
||||
super().__init__(
|
||||
value=value,
|
||||
label=label,
|
||||
show_label=show_label,
|
||||
visible=visible,
|
||||
elem_id=elem_id,
|
||||
every=every,
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
config = super().get_config()
|
||||
config["caption"] = self.caption
|
||||
return config
|
||||
|
||||
def get_block_name(self) -> str:
|
||||
return "plot"
|
||||
|
||||
@staticmethod
|
||||
def update(
|
||||
value: pd.DataFrame | Dict | Literal[_Keywords.NO_VALUE] = _Keywords.NO_VALUE,
|
||||
x: str | None = None,
|
||||
y: str | None = None,
|
||||
color: str | None = None,
|
||||
vertical: bool = True,
|
||||
group: str | None = None,
|
||||
title: str | None = None,
|
||||
tooltip: List[str] | str | None = None,
|
||||
x_title: str | None = None,
|
||||
y_title: str | None = None,
|
||||
color_legend_title: str | None = None,
|
||||
group_title: str | None = None,
|
||||
color_legend_position: str | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
y_lim: List[int] | None = None,
|
||||
caption: str | None = None,
|
||||
interactive: bool | None = True,
|
||||
label: str | None = None,
|
||||
show_label: bool = True,
|
||||
visible: bool = True,
|
||||
):
|
||||
"""Update an existing BarPlot component.
|
||||
|
||||
If updating any of the plot properties (color, size, etc) the value, x, and y parameters must be specified.
|
||||
|
||||
Parameters:
|
||||
value: The pandas dataframe containing the data to display in a scatter plot.
|
||||
x: Column corresponding to the x axis.
|
||||
y: Column corresponding to the y axis.
|
||||
color: The column to determine the bar color. Must be categorical (discrete values).
|
||||
vertical: If True, the bars will be displayed vertically. If False, the x and y axis will be switched, displaying the bars horizontally. Default is True.
|
||||
group: The column with which to split the overall plot into smaller subplots.
|
||||
title: The title to display on top of the chart.
|
||||
tooltip: The column (or list of columns) to display on the tooltip when a user hovers over a bar.
|
||||
x_title: The title given to the x axis. By default, uses the value of the x parameter.
|
||||
y_title: The title given to the y axis. By default, uses the value of the y parameter.
|
||||
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
|
||||
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
|
||||
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
|
||||
height: The height of the plot in pixels.
|
||||
width: The width of the plot in pixels.
|
||||
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
|
||||
caption: The (optional) caption to display below the plot.
|
||||
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
|
||||
label: The (optional) label to display on the top left corner of the plot.
|
||||
show_label: Whether the label should be displayed.
|
||||
visible: Whether the plot should be visible.
|
||||
"""
|
||||
properties = [
|
||||
x,
|
||||
y,
|
||||
color,
|
||||
vertical,
|
||||
group,
|
||||
title,
|
||||
tooltip,
|
||||
x_title,
|
||||
y_title,
|
||||
color_legend_title,
|
||||
group_title,
|
||||
color_legend_position,
|
||||
height,
|
||||
width,
|
||||
y_lim,
|
||||
interactive,
|
||||
]
|
||||
if any(properties):
|
||||
if not isinstance(value, pd.DataFrame):
|
||||
raise ValueError(
|
||||
"In order to update plot properties the value parameter "
|
||||
"must be provided, and it must be a Dataframe. Please pass a value "
|
||||
"parameter to gr.BarPlot.update."
|
||||
)
|
||||
if x is None or y is None:
|
||||
raise ValueError(
|
||||
"In order to update plot properties, the x and y axis data "
|
||||
"must be specified. Please pass valid values for x an y to "
|
||||
"gr.BarPlot.update."
|
||||
)
|
||||
chart = BarPlot.create_plot(value, *properties)
|
||||
value = {"type": "altair", "plot": chart.to_json(), "chart": "bar"}
|
||||
|
||||
updated_config = {
|
||||
"label": label,
|
||||
"show_label": show_label,
|
||||
"visible": visible,
|
||||
"value": value,
|
||||
"caption": caption,
|
||||
"__type__": "update",
|
||||
}
|
||||
return updated_config
|
||||
|
||||
@staticmethod
|
||||
def create_plot(
|
||||
value: pd.DataFrame,
|
||||
x: str,
|
||||
y: str,
|
||||
color: str | None = None,
|
||||
vertical: bool = True,
|
||||
group: str | None = None,
|
||||
title: str | None = None,
|
||||
tooltip: List[str] | str | None = None,
|
||||
x_title: str | None = None,
|
||||
y_title: str | None = None,
|
||||
color_legend_title: str | None = None,
|
||||
group_title: str | None = None,
|
||||
color_legend_position: str | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
y_lim: List[int] | None = None,
|
||||
interactive: bool | None = True,
|
||||
):
|
||||
"""Helper for creating the scatter plot."""
|
||||
interactive = True if interactive is None else interactive
|
||||
orientation = (
|
||||
dict(field=group, title=group_title if group_title is not None else group)
|
||||
if group
|
||||
else {}
|
||||
)
|
||||
|
||||
x_title = x_title or x
|
||||
y_title = y_title or y
|
||||
|
||||
# If horizontal, switch x and y
|
||||
if not vertical:
|
||||
y, x = x, y
|
||||
x = f"sum({x}):Q"
|
||||
y_title, x_title = x_title, y_title
|
||||
orientation = {"row": alt.Row(**orientation)} if orientation else {} # type: ignore
|
||||
x_lim = y_lim
|
||||
y_lim = None
|
||||
else:
|
||||
y = f"sum({y}):Q"
|
||||
x_lim = None
|
||||
orientation = {"column": alt.Column(**orientation)} if orientation else {} # type: ignore
|
||||
|
||||
encodings = dict(
|
||||
x=alt.X(
|
||||
x, # type: ignore
|
||||
title=x_title, # type: ignore
|
||||
scale=AltairPlot.create_scale(x_lim), # type: ignore
|
||||
),
|
||||
y=alt.Y(
|
||||
y, # type: ignore
|
||||
title=y_title, # type: ignore
|
||||
scale=AltairPlot.create_scale(y_lim), # type: ignore
|
||||
),
|
||||
**orientation,
|
||||
)
|
||||
properties = {}
|
||||
if title:
|
||||
properties["title"] = title
|
||||
if height:
|
||||
properties["height"] = height
|
||||
if width:
|
||||
properties["width"] = width
|
||||
|
||||
if color:
|
||||
domain = value[color].unique().tolist()
|
||||
range_ = list(range(len(domain)))
|
||||
encodings["color"] = {
|
||||
"field": color,
|
||||
"type": "nominal",
|
||||
"scale": {"domain": domain, "range": range_},
|
||||
"legend": AltairPlot.create_legend(
|
||||
position=color_legend_position, title=color_legend_title or color
|
||||
),
|
||||
}
|
||||
|
||||
if tooltip:
|
||||
encodings["tooltip"] = tooltip
|
||||
|
||||
chart = (
|
||||
alt.Chart(value) # type: ignore
|
||||
.mark_bar() # type: ignore
|
||||
.encode(**encodings)
|
||||
.properties(background="transparent", **properties)
|
||||
)
|
||||
if interactive:
|
||||
chart = chart.interactive()
|
||||
|
||||
return chart
|
||||
|
||||
def postprocess(self, y: pd.DataFrame | Dict | None) -> Dict[str, str] | None:
|
||||
# if None or update
|
||||
if y is None or isinstance(y, Dict):
|
||||
return y
|
||||
if self.x is None or self.y is None:
|
||||
raise ValueError("No value provided for required parameters `x` and `y`.")
|
||||
chart = self.create_plot(
|
||||
value=y,
|
||||
x=self.x,
|
||||
y=self.y,
|
||||
color=self.color,
|
||||
vertical=self.vertical,
|
||||
group=self.group,
|
||||
title=self.title,
|
||||
tooltip=self.tooltip,
|
||||
x_title=self.x_title,
|
||||
y_title=self.y_title,
|
||||
color_legend_title=self.color_legend_title,
|
||||
color_legend_position=self.color_legend_position,
|
||||
group_title=self.group_title,
|
||||
y_lim=self.y_lim,
|
||||
interactive=self.interactive_chart,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
)
|
||||
|
||||
return {"type": "altair", "plot": chart.to_json(), "chart": "bar"}
|
||||
|
||||
|
||||
@document("change")
|
||||
class Markdown(IOComponent, Changeable, SimpleSerializable):
|
||||
"""
|
||||
|
@ -418,7 +418,7 @@ class TestComponentsInBlocks:
|
||||
io_components = [
|
||||
c()
|
||||
for c in io_components
|
||||
if c not in [gr.State, gr.Button, gr.ScatterPlot, gr.LinePlot]
|
||||
if c not in [gr.State, gr.Button, gr.ScatterPlot, gr.LinePlot, gr.BarPlot]
|
||||
]
|
||||
with gr.Blocks() as demo:
|
||||
for component in io_components:
|
||||
|
@ -2005,6 +2005,13 @@ def test_dataset_calls_as_example(*mocks):
|
||||
|
||||
cars = vega_datasets.data.cars()
|
||||
stocks = vega_datasets.data.stocks()
|
||||
barley = vega_datasets.data.barley()
|
||||
simple = pd.DataFrame(
|
||||
{
|
||||
"a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"],
|
||||
"b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestScatterPlot:
|
||||
@ -2348,3 +2355,146 @@ class TestLinePlot:
|
||||
)
|
||||
assert isinstance(plot.value, dict)
|
||||
assert isinstance(plot.value["plot"], str)
|
||||
|
||||
|
||||
class TestBarPlot:
|
||||
def test_get_config(self):
|
||||
assert gr.BarPlot().get_config() == {
|
||||
"caption": None,
|
||||
"elem_id": None,
|
||||
"interactive": None,
|
||||
"label": None,
|
||||
"name": "plot",
|
||||
"root_url": None,
|
||||
"show_label": True,
|
||||
"style": {},
|
||||
"value": None,
|
||||
"visible": True,
|
||||
}
|
||||
|
||||
def test_no_color(self):
|
||||
plot = gr.BarPlot(
|
||||
x="a",
|
||||
y="b",
|
||||
tooltip=["a", "b"],
|
||||
title="Made Up Bar Plot",
|
||||
x_title="Variable A",
|
||||
)
|
||||
output = plot.postprocess(simple)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
assert output["chart"] == "bar"
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "a"
|
||||
assert config["encoding"]["x"]["title"] == "Variable A"
|
||||
assert config["encoding"]["y"]["field"] == "b"
|
||||
assert config["encoding"]["y"]["title"] == "b"
|
||||
|
||||
assert config["title"] == "Made Up Bar Plot"
|
||||
assert "height" not in config
|
||||
assert "width" not in config
|
||||
|
||||
def test_height_width(self):
|
||||
plot = gr.BarPlot(x="a", y="b", height=100, width=200)
|
||||
output = plot.postprocess(simple)
|
||||
assert sorted(list(output.keys())) == ["chart", "plot", "type"]
|
||||
config = json.loads(output["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
||||
output = gr.BarPlot.update(simple, x="a", y="b", height=100, width=200)
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["height"] == 100
|
||||
assert config["width"] == 200
|
||||
|
||||
def test_ylim(self):
|
||||
plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100])
|
||||
output = plot.postprocess(simple)
|
||||
config = json.loads(output["plot"])
|
||||
assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}
|
||||
|
||||
def test_horizontal(self):
|
||||
output = gr.BarPlot.update(
|
||||
simple,
|
||||
x="a",
|
||||
y="b",
|
||||
x_title="Variable A",
|
||||
y_title="Variable B",
|
||||
title="Simple Bar Plot with made up data",
|
||||
tooltip=["a", "b"],
|
||||
vertical=False,
|
||||
y_lim=[20, 100],
|
||||
)
|
||||
assert output["value"]["chart"] == "bar"
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["encoding"]["x"]["field"] == "b"
|
||||
assert config["encoding"]["x"]["scale"] == {"domain": [20, 100]}
|
||||
assert config["encoding"]["x"]["title"] == "Variable B"
|
||||
|
||||
assert config["encoding"]["y"]["field"] == "a"
|
||||
assert config["encoding"]["y"]["title"] == "Variable A"
|
||||
|
||||
def test_stack_via_color(self):
|
||||
output = gr.BarPlot.update(
|
||||
barley,
|
||||
x="variety",
|
||||
y="yield",
|
||||
color="site",
|
||||
title="Barley Yield Data",
|
||||
color_legend_title="Site",
|
||||
color_legend_position="bottom",
|
||||
)
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["encoding"]["color"]["field"] == "site"
|
||||
assert config["encoding"]["color"]["legend"] == {
|
||||
"title": "Site",
|
||||
"orient": "bottom",
|
||||
}
|
||||
assert config["encoding"]["color"]["scale"] == {
|
||||
"domain": [
|
||||
"University Farm",
|
||||
"Waseca",
|
||||
"Morris",
|
||||
"Crookston",
|
||||
"Grand Rapids",
|
||||
"Duluth",
|
||||
],
|
||||
"range": [0, 1, 2, 3, 4, 5],
|
||||
}
|
||||
|
||||
def test_group(self):
|
||||
output = gr.BarPlot.update(
|
||||
barley,
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
group_title="",
|
||||
tooltip=["yield", "site", "year"],
|
||||
)
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["encoding"]["column"] == {"field": "site", "title": ""}
|
||||
|
||||
def test_group_horizontal(self):
|
||||
output = gr.BarPlot.update(
|
||||
barley,
|
||||
x="year",
|
||||
y="yield",
|
||||
color="year",
|
||||
group="site",
|
||||
title="Barley Yield by Year and Site",
|
||||
group_title="Site Title",
|
||||
tooltip=["yield", "site", "year"],
|
||||
vertical=False,
|
||||
)
|
||||
config = json.loads(output["value"]["plot"])
|
||||
assert config["encoding"]["row"] == {"field": "site", "title": "Site Title"}
|
||||
|
||||
def test_barplot_accepts_fn_as_value(self):
|
||||
plot = gr.BarPlot(
|
||||
value=lambda: barley.sample(frac=0.1, replace=False),
|
||||
x="year",
|
||||
y="yield",
|
||||
)
|
||||
assert isinstance(plot.value, dict)
|
||||
assert isinstance(plot.value["plot"], str)
|
||||
|
@ -58,6 +58,14 @@
|
||||
);
|
||||
}
|
||||
});
|
||||
break;
|
||||
case "bar":
|
||||
if (spec.encoding.color) {
|
||||
spec.encoding.color.scale.range = spec.encoding.color.scale.range.map(
|
||||
(e, i) => get_color(i)
|
||||
);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -26,7 +26,8 @@ export function create_config(darkmode: boolean): VegaConfig {
|
||||
title: {
|
||||
color: darkmode ? dark : light,
|
||||
font: "sans-serif",
|
||||
fontWeight: "normal"
|
||||
fontWeight: "normal",
|
||||
anchor: "middle"
|
||||
}
|
||||
};
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user