From a238af4d688c4e030e37c2ef01d5c80d6d940912 Mon Sep 17 00:00:00 2001 From: aliabid94 Date: Mon, 22 Jul 2024 09:52:48 -0700 Subject: [PATCH] Refactor plots to drop `altair` and use `vega.js` directly (#8807) * changes * add changeset * changes * changes * changes * add changeset * changes * add changeset * changes * add changeset * add changeset * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * add changeset * changes * changes * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid * Update gradio/blocks.py Co-authored-by: Abubakar Abid * changes * changes * changes * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid * Update gradio/components/native_plot.py Co-authored-by: Abubakar Abid * changes * changes * changes --------- Co-authored-by: Ali Abid Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/tangy-beds-guess.md | 8 + demo/blocks_xray/run.ipynb | 2 +- demo/blocks_xray/run.py | 2 +- demo/native_plots/bar_plot_demo.py | 176 +++--- demo/native_plots/data.py | 20 + demo/native_plots/line_plot_demo.py | 137 +++-- demo/native_plots/run.ipynb | 2 +- demo/native_plots/run.py | 14 +- demo/native_plots/scatter_plot_demo.py | 97 ++-- gradio/blocks.py | 4 + gradio/chat_interface.py | 3 + .../cli/commands/components/_create_utils.py | 6 +- gradio/components/__init__.py | 4 +- gradio/components/native_plot.py | 263 +++++++++ gradio/data_classes.py | 1 + gradio/interface.py | 3 + gradio/monitoring_dashboard.py | 89 +-- gradio/utils.py | 3 +- js/app/package.json | 1 + js/app/src/Embed.svelte | 12 +- js/app/src/Index.svelte | 2 + js/datetime/Index.svelte | 13 +- js/nativeplot/CHANGELOG.md | 1 + js/nativeplot/Example.svelte | 11 + js/nativeplot/Index.svelte | 505 ++++++++++++++++++ js/nativeplot/package.json | 28 + package.json | 1 + pnpm-lock.yaml | 47 +- requirements.txt | 1 - test/components/plot_data.py | 6 +- test/components/test_bar_plot.py | 118 ---- test/components/test_line_plot.py | 112 ---- test/components/test_native_plots.py | 32 ++ test/components/test_scatter_plot.py | 168 ------ 34 files changed, 1204 insertions(+), 688 deletions(-) create mode 100644 .changeset/tangy-beds-guess.md create mode 100644 demo/native_plots/data.py create mode 100644 gradio/components/native_plot.py create mode 100644 js/nativeplot/CHANGELOG.md create mode 100644 js/nativeplot/Example.svelte create mode 100644 js/nativeplot/Index.svelte create mode 100644 js/nativeplot/package.json delete mode 100644 test/components/test_bar_plot.py delete mode 100644 test/components/test_line_plot.py create mode 100644 test/components/test_native_plots.py delete mode 100644 test/components/test_scatter_plot.py diff --git a/.changeset/tangy-beds-guess.md b/.changeset/tangy-beds-guess.md new file mode 100644 index 0000000000..b83d01e644 --- /dev/null +++ b/.changeset/tangy-beds-guess.md @@ -0,0 +1,8 @@ +--- +"@gradio/app": minor +"@gradio/datetime": minor +"@gradio/nativeplot": minor +"gradio": minor +--- + +feat:Refactor plots to drop `altair` and use `vega.js` directly diff --git a/demo/blocks_xray/run.ipynb b/demo/blocks_xray/run.ipynb index 6754b73faf..c8cf25affa 100644 --- a/demo/blocks_xray/run.ipynb +++ b/demo/blocks_xray/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: blocks_xray"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "disease_values = [0.25, 0.5, 0.75]\n", "\n", "def xray_model(diseases, img):\n", " return [{disease: disease_values[idx] for idx,disease in enumerate(diseases)}]\n", "\n", "\n", "def ct_model(diseases, img):\n", " return [{disease: 0.1 for disease in diseases}]\n", "\n", "with gr.Blocks(fill_width=True) as demo:\n", " gr.Markdown(\n", " \"\"\"\n", "# Detect Disease From Scan\n", "With this model you can lorem ipsum\n", "- ipsum 1\n", "- ipsum 2\n", "\"\"\"\n", " )\n", " gr.DuplicateButton()\n", " disease = gr.CheckboxGroup(\n", " info=\"Select the diseases you want to scan for.\",\n", " choices=[\"Covid\", \"Malaria\", \"Lung Cancer\"], label=\"Disease to Scan For\"\n", " )\n", " slider = gr.Slider(0, 100)\n", "\n", " with gr.Tab(\"X-ray\") as x_tab:\n", " with gr.Row():\n", " xray_scan = gr.Image()\n", " xray_results = gr.JSON()\n", " xray_run = gr.Button(\"Run\")\n", " xray_run.click(\n", " xray_model,\n", " inputs=[disease, xray_scan],\n", " outputs=xray_results,\n", " api_name=\"xray_model\"\n", " )\n", "\n", " with gr.Tab(\"CT Scan\"):\n", " with gr.Row():\n", " ct_scan = gr.Image()\n", " ct_results = gr.JSON()\n", " ct_run = gr.Button(\"Run\")\n", " ct_run.click(\n", " ct_model,\n", " inputs=[disease, ct_scan],\n", " outputs=ct_results,\n", " api_name=\"ct_model\"\n", " )\n", "\n", " upload_btn = gr.Button(\"Upload Results\", variant=\"primary\")\n", " upload_btn.click(\n", " lambda ct, xr: None,\n", " inputs=[ct_results, xray_results],\n", " outputs=[],\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/blocks_xray/run.py b/demo/blocks_xray/run.py index e8e8e8167f..f1cb510263 100644 --- a/demo/blocks_xray/run.py +++ b/demo/blocks_xray/run.py @@ -10,7 +10,7 @@ def xray_model(diseases, img): def ct_model(diseases, img): return [{disease: 0.1 for disease in diseases}] -with gr.Blocks() as demo: +with gr.Blocks(fill_width=True) as demo: gr.Markdown( """ # Detect Disease From Scan diff --git a/demo/native_plots/bar_plot_demo.py b/demo/native_plots/bar_plot_demo.py index 302765788a..0b538a3cd0 100644 --- a/demo/native_plots/bar_plot_demo.py +++ b/demo/native_plots/bar_plot_demo.py @@ -1,111 +1,77 @@ import gradio as gr -import pandas as pd +import numpy as np +from data import temp_sensor_data, food_rating_data -from vega_datasets import data +with gr.Blocks() as bar_plots: + with gr.Row(): + start = gr.DateTime("2021-01-01 00:00:00", label="Start") + end = gr.DateTime("2021-01-05 00:00:00", label="End") + apply_btn = gr.Button("Apply", scale=0) + with gr.Row(): + group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by") + aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation") -barley = data.barley() -simple = pd.DataFrame({ - 'a': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I'], - 'b': [28, 55, 43, 91, 81, 53, 19, 87, 52] -}) - -def bar_plot_fn(display): - if display == "simple": - return gr.BarPlot( - simple, - x="a", - y="b", - color=None, - group=None, - title="Simple Bar Plot with made up data", - tooltip=['a', 'b'], - y_lim=[20, 100], - x_title=None, - y_title=None, - vertical=True, - ) - elif display == "stacked": - return gr.BarPlot( - barley, - x="variety", - y="yield", - color="site", - group=None, - title="Barley Yield Data", - tooltip=['variety', 'site'], - y_lim=None, - x_title=None, - y_title=None, - vertical=True, - ) - elif display == "grouped": - return gr.BarPlot( - barley.astype({"year": str}), - x="year", - y="yield", - color="year", - group="site", - title="Barley Yield by Year and Site", - tooltip=["yield", "site", "year"], - y_lim=None, - x_title=None, - y_title=None, - vertical=True, - ) - elif display == "simple-horizontal": - return gr.BarPlot( - simple, - x="a", - y="b", - color=None, - group=None, - title="Simple Bar Plot with made up data", - tooltip=['a', 'b'], - y_lim=[20, 100], - x_title="Variable A", - y_title="Variable B", - vertical=False, - ) - elif display == "stacked-horizontal": - return gr.BarPlot( - barley, - x="variety", - y="yield", - color="site", - group=None, - title="Barley Yield Data", - tooltip=['variety', 'site'], - y_lim=None, - x_title=None, - y_title=None, - vertical=False, - ) - elif display == "grouped-horizontal": - return gr.BarPlot( - barley.astype({"year": str}), - x="year", - y="yield", - color="year", - group="site", - title="Barley Yield by Year and Site", - group_title="", - tooltip=["yield", "site", "year"], - y_lim=None, - x_title=None, - y_title=None, - vertical=False - ) - - -with gr.Blocks() as bar_plot: - display = gr.Dropdown( - choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"], - value="simple", - label="Type of Bar Plot" + temp_by_time = gr.BarPlot( + temp_sensor_data, + x="time", + y="temperature", ) - plot = gr.BarPlot(show_label=False) - display.change(bar_plot_fn, inputs=display, outputs=plot) - bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot) + temp_by_time_location = gr.BarPlot( + temp_sensor_data, + x="time", + y="temperature", + color="location", + ) + + time_graphs = [temp_by_time, temp_by_time_location] + group_by.change( + lambda group: [gr.BarPlot(x_bin=None if group == "None" else group)] * len(time_graphs), + group_by, + time_graphs + ) + aggregate.change( + lambda aggregate: [gr.BarPlot(y_aggregate=aggregate)] * len(time_graphs), + aggregate, + time_graphs + ) + + + def rescale(select: gr.SelectData): + return select.index + rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end]) + + for trigger in [apply_btn.click, rescale_evt.then]: + trigger( + lambda start, end: [gr.BarPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs + ) + + with gr.Row(): + price_by_cuisine = gr.BarPlot( + food_rating_data, + x="cuisine", + y="price", + ) + with gr.Column(scale=0): + gr.Button("Sort $ > $$$").click(lambda: gr.BarPlot(sort="y"), None, price_by_cuisine) + gr.Button("Sort $$$ > $").click(lambda: gr.BarPlot(sort="-y"), None, price_by_cuisine) + gr.Button("Sort A > Z").click(lambda: gr.BarPlot(sort=["Chinese", "Italian", "Mexican"]), None, price_by_cuisine) + + with gr.Row(): + price_by_rating = gr.BarPlot( + food_rating_data, + x="rating", + y="price", + x_bin=1, + ) + price_by_rating_color = gr.BarPlot( + food_rating_data, + x="rating", + y="price", + color="cuisine", + x_bin=1, + color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"}, + ) + if __name__ == "__main__": - bar_plot.launch() \ No newline at end of file + bar_plots.launch() diff --git a/demo/native_plots/data.py b/demo/native_plots/data.py new file mode 100644 index 0000000000..cafc6046cc --- /dev/null +++ b/demo/native_plots/data.py @@ -0,0 +1,20 @@ +import pandas as pd +from random import randint, choice, random + +temp_sensor_data = pd.DataFrame( + { + "time": pd.date_range("2021-01-01", end="2021-01-05", periods=200), + "temperature": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)], + "humidity": [randint(50 + 10 * (i % 2), 65 + 15 * (i % 2)) for i in range(200)], + "location": ["indoor", "outdoor"] * 100, + } +) + +food_rating_data = pd.DataFrame( + { + "cuisine": [["Italian", "Mexican", "Chinese"][i % 3] for i in range(100)], + "rating": [random() * 4 + 0.5 * (i % 3) for i in range(100)], + "price": [randint(10, 50) + 4 * (i % 3) for i in range(100)], + "wait": [random() for i in range(100)], + } +) \ No newline at end of file diff --git a/demo/native_plots/line_plot_demo.py b/demo/native_plots/line_plot_demo.py index a6c5591708..a08d6fcee9 100644 --- a/demo/native_plots/line_plot_demo.py +++ b/demo/native_plots/line_plot_demo.py @@ -1,82 +1,69 @@ import gradio as gr -from vega_datasets import data +import numpy as np +from data import temp_sensor_data, food_rating_data -stocks = data.stocks() -gapminder = data.gapminder() -gapminder = gapminder.loc[ - gapminder.country.isin(["Argentina", "Australia", "Afghanistan"]) -] -climate = data.climate() -seattle_weather = data.seattle_weather() +with gr.Blocks() as line_plots: + with gr.Row(): + start = gr.DateTime("2021-01-01 00:00:00", label="Start") + end = gr.DateTime("2021-01-05 00:00:00", label="End") + apply_btn = gr.Button("Apply", scale=0) + with gr.Row(): + group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by") + aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation") - -def line_plot_fn(dataset): - if dataset == "stocks": - return gr.LinePlot( - stocks, - x="date", - y="price", - color="symbol", - x_lim=None, - y_lim=None, - stroke_dash=None, - tooltip=['date', 'price', 'symbol'], - overlay_point=False, - title="Stock Prices", - stroke_dash_legend_title=None, - ) - elif dataset == "climate": - return gr.LinePlot( - climate, - x="DATE", - y="HLY-TEMP-NORMAL", - color=None, - x_lim=None, - y_lim=[250, 500], - stroke_dash=None, - tooltip=['DATE', 'HLY-TEMP-NORMAL'], - overlay_point=False, - title="Climate", - stroke_dash_legend_title=None, - ) - elif dataset == "seattle_weather": - return gr.LinePlot( - seattle_weather, - x="date", - y="temp_min", - color=None, - x_lim=None, - y_lim=None, - stroke_dash=None, - tooltip=["weather", "date"], - overlay_point=True, - title="Seattle Weather", - stroke_dash_legend_title=None, - ) - elif dataset == "gapminder": - return gr.LinePlot( - gapminder, - x="year", - y="life_expect", - color="country", - x_lim=[1950, 2010], - y_lim=None, - stroke_dash="cluster", - tooltip=['country', 'life_expect'], - overlay_point=False, - title="Life expectancy for countries", - ) - - -with gr.Blocks() as line_plot: - dataset = gr.Dropdown( - choices=["stocks", "climate", "seattle_weather", "gapminder"], - value="stocks", + temp_by_time = gr.LinePlot( + temp_sensor_data, + x="time", + y="temperature", ) - plot = gr.LinePlot() - dataset.change(line_plot_fn, inputs=dataset, outputs=plot) - line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot) + temp_by_time_location = gr.LinePlot( + temp_sensor_data, + x="time", + y="temperature", + color="location", + ) + + time_graphs = [temp_by_time, temp_by_time_location] + group_by.change( + lambda group: [gr.LinePlot(x_bin=None if group == "None" else group)] * len(time_graphs), + group_by, + time_graphs + ) + aggregate.change( + lambda aggregate: [gr.LinePlot(y_aggregate=aggregate)] * len(time_graphs), + aggregate, + time_graphs + ) + + + def rescale(select: gr.SelectData): + return select.index + rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end]) + + for trigger in [apply_btn.click, rescale_evt.then]: + trigger( + lambda start, end: [gr.LinePlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs + ) + + price_by_cuisine = gr.LinePlot( + food_rating_data, + x="cuisine", + y="price", + ) + with gr.Row(): + price_by_rating = gr.LinePlot( + food_rating_data, + x="rating", + y="price", + ) + price_by_rating_color = gr.LinePlot( + food_rating_data, + x="rating", + y="price", + color="cuisine", + color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"}, + ) if __name__ == "__main__": - line_plot.launch() + line_plots.launch() diff --git a/demo/native_plots/run.ipynb b/demo/native_plots/run.ipynb index f79c3ad985..a96c6b0b7f 100644 --- a/demo/native_plots/run.ipynb +++ b/demo/native_plots/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plot\n", "from line_plot_demo import line_plot\n", "from bar_plot_demo import bar_plot\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plot.render()\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plot.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plot.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: native_plots"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio vega_datasets"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/bar_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/data.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/line_plot_demo.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/native_plots/scatter_plot_demo.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "from scatter_plot_demo import scatter_plots\n", "from line_plot_demo import line_plots\n", "from bar_plot_demo import bar_plots\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Tabs():\n", " with gr.TabItem(\"Line Plot\"):\n", " line_plots.render()\n", " with gr.TabItem(\"Scatter Plot\"):\n", " scatter_plots.render()\n", " with gr.TabItem(\"Bar Plot\"):\n", " bar_plots.render()\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/native_plots/run.py b/demo/native_plots/run.py index 3ed9ed8309..d4e771ab04 100644 --- a/demo/native_plots/run.py +++ b/demo/native_plots/run.py @@ -1,18 +1,18 @@ import gradio as gr -from scatter_plot_demo import scatter_plot -from line_plot_demo import line_plot -from bar_plot_demo import bar_plot +from scatter_plot_demo import scatter_plots +from line_plot_demo import line_plots +from bar_plot_demo import bar_plots with gr.Blocks() as demo: with gr.Tabs(): - with gr.TabItem("Scatter Plot"): - scatter_plot.render() with gr.TabItem("Line Plot"): - line_plot.render() + line_plots.render() + with gr.TabItem("Scatter Plot"): + scatter_plots.render() with gr.TabItem("Bar Plot"): - bar_plot.render() + bar_plots.render() if __name__ == "__main__": demo.launch() diff --git a/demo/native_plots/scatter_plot_demo.py b/demo/native_plots/scatter_plot_demo.py index 54d9210552..342de33b57 100644 --- a/demo/native_plots/scatter_plot_demo.py +++ b/demo/native_plots/scatter_plot_demo.py @@ -1,46 +1,71 @@ import gradio as gr +import numpy as np +from data import temp_sensor_data, food_rating_data -from vega_datasets import data +with gr.Blocks() as scatter_plots: + with gr.Row(): + start = gr.DateTime("2021-01-01 00:00:00", label="Start") + end = gr.DateTime("2021-01-05 00:00:00", label="End") + apply_btn = gr.Button("Apply", scale=0) + with gr.Row(): + group_by = gr.Radio(["None", "30m", "1h", "4h", "1d"], value="None", label="Group by") + aggregate = gr.Radio(["sum", "mean", "median", "min", "max"], value="sum", label="Aggregation") -cars = data.cars() -iris = data.iris() + temp_by_time = gr.ScatterPlot( + temp_sensor_data, + x="time", + y="temperature", + ) + temp_by_time_location = gr.ScatterPlot( + temp_sensor_data, + x="time", + y="temperature", + color="location", + ) + + time_graphs = [temp_by_time, temp_by_time_location] + group_by.change( + lambda group: [gr.ScatterPlot(x_bin=None if group == "None" else group)] * len(time_graphs), + group_by, + time_graphs + ) + aggregate.change( + lambda aggregate: [gr.ScatterPlot(y_aggregate=aggregate)] * len(time_graphs), + aggregate, + time_graphs + ) -def scatter_plot_fn(dataset): - if dataset == "iris": - return gr.ScatterPlot( - value=iris, - x="petalWidth", - y="petalLength", - color=None, - title="Iris Dataset", - x_title="Petal Width", - y_title="Petal Length", - tooltip=["petalWidth", "petalLength", "species"], - caption="", - height=600, - width=600, + # def rescale(select: gr.SelectData): + # return select.index + # rescale_evt = gr.on([plot.select for plot in time_graphs], rescale, None, [start, end]) + + # for trigger in [apply_btn.click, rescale_evt.then]: + # trigger( + # lambda start, end: [gr.ScatterPlot(x_lim=[start, end])] * len(time_graphs), [start, end], time_graphs + # ) + + price_by_cuisine = gr.ScatterPlot( + food_rating_data, + x="cuisine", + y="price", + ) + with gr.Row(): + price_by_rating = gr.ScatterPlot( + food_rating_data, + x="rating", + y="price", + color="wait", + show_actions_button=True, ) - else: - return gr.ScatterPlot( - value=cars, - x="Horsepower", - y="Miles_per_Gallon", - color="Origin", - tooltip="Name", - title="Car Data", - y_title="Miles per Gallon", - caption="MPG vs Horsepower of various cars", - height=None, - width=None, + price_by_rating_color = gr.ScatterPlot( + food_rating_data, + x="rating", + y="price", + color="cuisine", + # color_map={"Italian": "red", "Mexican": "green", "Chinese": "blue"}, ) -with gr.Blocks() as scatter_plot: - dataset = gr.Dropdown(choices=["cars", "iris"], value="cars") - plot = gr.ScatterPlot(show_label=False) - dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot) - scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot) - if __name__ == "__main__": - scatter_plot.launch() + scatter_plots.launch() diff --git a/gradio/blocks.py b/gradio/blocks.py index 860fd02436..9bf2eee3e8 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -942,6 +942,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): js: str | None = None, head: str | None = None, fill_height: bool = False, + fill_width: bool = False, delete_cache: tuple[int, int] | None = None, **kwargs, ): @@ -955,6 +956,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): js: Custom js as a string or path to a js file. The custom js should be in the form of a single js function. This function will automatically be executed when the page loads. For more flexibility, use the head parameter to insert js inside + +{#if title} + {title} +{:else} + {x} x {y} +{/if} diff --git a/js/nativeplot/Index.svelte b/js/nativeplot/Index.svelte new file mode 100644 index 0000000000..25be8f1d71 --- /dev/null +++ b/js/nativeplot/Index.svelte @@ -0,0 +1,505 @@ + + + + {#if loading_status} + gradio.dispatch("clear_status", loading_status)} + /> + {/if} + {label} +
+ {#if value} + {#if caption} +

{caption}

+ {/if} + {:else} + + {/if} +
+ + diff --git a/js/nativeplot/package.json b/js/nativeplot/package.json new file mode 100644 index 0000000000..776a5c2c20 --- /dev/null +++ b/js/nativeplot/package.json @@ -0,0 +1,28 @@ +{ + "name": "@gradio/nativeplot", + "version": "0.0.1", + "description": "Gradio UI packages", + "type": "module", + "author": "", + "license": "ISC", + "private": false, + "main_changeset": true, + "exports": { + ".": "./Index.svelte", + "./example": "./Example.svelte", + "./package.json": "./package.json" + }, + "dependencies": { + "@gradio/atoms": "workspace:^", + "@gradio/icons": "workspace:^", + "@gradio/statustracker": "workspace:^", + "@gradio/utils": "workspace:^", + "@gradio/theme": "workspace:^", + "vega": "^5.23.0", + "vega-embed": "^6.25.0", + "vega-lite": "^5.12.0" + }, + "devDependencies": { + "@gradio/preview": "workspace:^" + } +} diff --git a/package.json b/package.json index 16daf359c2..5d75fb45e5 100644 --- a/package.json +++ b/package.json @@ -122,6 +122,7 @@ "@gradio/markdown": "workspace:^", "@gradio/model3d": "workspace:^", "@gradio/multimodaltextbox": "workspace:^", + "@gradio/nativeplot": "workspace:^", "@gradio/number": "workspace:^", "@gradio/paramviewer": "workspace:^", "@gradio/plot": "workspace:^", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7b69d1bec7..68fb34f7df 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -270,6 +270,9 @@ importers: '@gradio/multimodaltextbox': specifier: workspace:^ version: link:js/multimodaltextbox + '@gradio/nativeplot': + specifier: workspace:^ + version: link:js/nativeplot '@gradio/number': specifier: workspace:^ version: link:js/number @@ -625,6 +628,9 @@ importers: '@gradio/multimodaltextbox': specifier: workspace:^ version: link:../multimodaltextbox + '@gradio/nativeplot': + specifier: workspace:^ + version: link:../nativeplot '@gradio/number': specifier: workspace:^ version: link:../number @@ -1549,6 +1555,37 @@ importers: specifier: workspace:^ version: link:../preview + js/nativeplot: + dependencies: + '@gradio/atoms': + specifier: workspace:^ + version: link:../atoms + '@gradio/icons': + specifier: workspace:^ + version: link:../icons + '@gradio/statustracker': + specifier: workspace:^ + version: link:../statustracker + '@gradio/theme': + specifier: workspace:^ + version: link:../theme + '@gradio/utils': + specifier: workspace:^ + version: link:../utils + vega: + specifier: ^5.23.0 + version: 5.23.0 + vega-embed: + specifier: ^6.25.0 + version: 6.25.0(vega-lite@5.12.0(vega@5.23.0))(vega@5.23.0) + vega-lite: + specifier: ^5.12.0 + version: 5.12.0(vega@5.23.0) + devDependencies: + '@gradio/preview': + specifier: workspace:^ + version: link:../preview + js/number: dependencies: '@gradio/atoms': @@ -6025,10 +6062,6 @@ packages: engines: {node: '>=12'} hasBin: true - escalade@3.1.1: - resolution: {integrity: sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==} - engines: {node: '>=6'} - escalade@3.1.2: resolution: {integrity: sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==} engines: {node: '>=6'} @@ -14422,8 +14455,6 @@ snapshots: '@esbuild/win32-ia32': 0.21.0 '@esbuild/win32-x64': 0.21.0 - escalade@3.1.1: {} - escalade@3.1.2: {} escape-html@1.0.3: {} @@ -18217,7 +18248,7 @@ snapshots: d3-timer: 3.0.1 vega-dataflow: 5.7.5 vega-format: 1.1.1 - vega-functions: 5.13.2 + vega-functions: 5.14.0 vega-runtime: 6.1.4 vega-scenegraph: 4.10.2 vega-util: 1.17.2 @@ -18565,7 +18596,7 @@ snapshots: yargs@17.7.2: dependencies: cliui: 8.0.1 - escalade: 3.1.1 + escalade: 3.1.2 get-caller-file: 2.0.5 require-directory: 2.1.1 string-width: 4.2.3 diff --git a/requirements.txt b/requirements.txt index 126e913f17..50c3a8b4c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ aiofiles>=22.0,<24.0 -altair>=5.0,<6.0 anyio>=3.0,<5.0 fastapi ffmpy diff --git a/test/components/plot_data.py b/test/components/plot_data.py index db3ab2a7a0..fdfae60af6 100644 --- a/test/components/plot_data.py +++ b/test/components/plot_data.py @@ -1,12 +1,14 @@ +from datetime import datetime, timedelta + import pandas as pd import vega_datasets -cars = vega_datasets.data.cars() -stocks = vega_datasets.data.stocks() barley = vega_datasets.data.barley() simple = pd.DataFrame( { "a": ["A", "B", "C", "D", "E", "F", "G", "H", "I"], "b": [28, 55, 43, 91, 81, 53, 19, 87, 52], + "c": [0] * 9, + "date": [datetime.now() - timedelta(days=i) for i in range(9)], } ) diff --git a/test/components/test_bar_plot.py b/test/components/test_bar_plot.py deleted file mode 100644 index 0f21cf0fad..0000000000 --- a/test/components/test_bar_plot.py +++ /dev/null @@ -1,118 +0,0 @@ -import json -from unittest.mock import MagicMock, patch - -import gradio as gr - -from .plot_data import barley, simple - - -class TestBarPlot: - @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")}) - def test_get_config(self): - assert gr.BarPlot().get_config() == { - "caption": None, - "elem_id": None, - "elem_classes": [], - "interactive": None, - "label": None, - "name": "barplot", - "bokeh_version": "3.0.3", - "show_actions_button": False, - "proxy_url": None, - "show_label": False, - "container": True, - "min_width": 160, - "scale": None, - "value": None, - "visible": True, - "x": None, - "y": None, - "color": None, - "vertical": True, - "group": None, - "title": None, - "tooltip": None, - "x_title": None, - "y_title": None, - "color_legend_title": None, - "group_title": None, - "color_legend_position": None, - "height": None, - "width": None, - "y_lim": None, - "x_label_angle": None, - "y_label_angle": None, - "sort": None, - "_selectable": False, - "key": None, - } - - def test_no_color(self): - plot = gr.BarPlot( - x="a", - y="b", - tooltip=["a", "b"], - title="Made Up Bar Plot", - x_title="Variable A", - sort="x", - ) - assert (output := plot.postprocess(simple)) - output = output.model_dump() - assert sorted(output.keys()) == ["chart", "plot", "type"] - assert output["chart"] == "bar" - config = json.loads(output["plot"]) - assert config["encoding"]["x"]["sort"] == "x" - assert config["encoding"]["x"]["field"] == "a" - assert config["encoding"]["x"]["title"] == "Variable A" - assert config["encoding"]["y"]["field"] == "b" - assert config["encoding"]["y"]["title"] == "b" - - assert config["title"] == "Made Up Bar Plot" - assert "height" not in config - assert "width" not in config - - def test_height_width(self): - plot = gr.BarPlot(x="a", y="b", height=100, width=200) - assert (output := plot.postprocess(simple)) - output = output.model_dump() - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - assert config["height"] == 100 - assert config["width"] == 200 - - def test_ylim(self): - plot = gr.BarPlot(x="a", y="b", y_lim=[15, 100]) - assert (output := plot.postprocess(simple)) - output = output.model_dump() - config = json.loads(output["plot"]) - assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]} - - def test_horizontal(self): - output = gr.BarPlot( - simple, - x="a", - y="b", - x_title="Variable A", - y_title="Variable B", - title="Simple Bar Plot with made up data", - tooltip=["a", "b"], - vertical=False, - y_lim=[20, 100], - ).get_config() - assert output["value"]["chart"] == "bar" - config = json.loads(output["value"]["plot"]) - assert config["encoding"]["x"]["field"] == "b" - assert config["encoding"]["x"]["scale"] == {"domain": [20, 100]} - assert config["encoding"]["x"]["title"] == "Variable B" - - assert config["encoding"]["y"]["field"] == "a" - assert config["encoding"]["y"]["title"] == "Variable A" - - def test_barplot_accepts_fn_as_value(self): - plot = gr.BarPlot( - value=lambda: barley.sample(frac=0.1, replace=False), - x="year", - y="yield", - ) - assert isinstance(plot.value, dict) - assert isinstance(plot.value["plot"], str) diff --git a/test/components/test_line_plot.py b/test/components/test_line_plot.py deleted file mode 100644 index cd0a4f22e6..0000000000 --- a/test/components/test_line_plot.py +++ /dev/null @@ -1,112 +0,0 @@ -import json -from unittest.mock import MagicMock, patch - -import gradio as gr - -from .plot_data import stocks - - -class TestLinePlot: - @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")}) - def test_get_config(self): - assert gr.LinePlot().get_config() == { - "caption": None, - "elem_id": None, - "elem_classes": [], - "interactive": None, - "label": None, - "name": "lineplot", - "bokeh_version": "3.0.3", - "show_actions_button": False, - "proxy_url": None, - "show_label": False, - "container": True, - "min_width": 160, - "scale": None, - "value": None, - "visible": True, - "x": None, - "y": None, - "color": None, - "stroke_dash": None, - "overlay_point": None, - "title": None, - "tooltip": [], - "x_title": None, - "y_title": None, - "color_legend_title": None, - "stroke_dash_legend_title": None, - "color_legend_position": None, - "stroke_dash_legend_position": None, - "height": None, - "width": None, - "x_lim": None, - "y_lim": None, - "x_label_angle": None, - "y_label_angle": None, - "_selectable": False, - "key": None, - } - - def test_no_color(self): - plot = gr.LinePlot( - x="date", - y="price", - tooltip=["symbol", "price"], - title="Stock Performance", - x_title="Trading Day", - ) - output = plot.postprocess(stocks).model_dump() # type: ignore - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - for layer in config["layer"]: - assert layer["mark"]["type"] in ["line", "point"] - assert layer["encoding"]["x"]["field"] == "date" - assert layer["encoding"]["x"]["title"] == "Trading Day" - assert layer["encoding"]["y"]["field"] == "price" - - assert config["title"] == "Stock Performance" - assert "height" not in config - assert "width" not in config - - def test_height_width(self): - plot = gr.LinePlot(x="date", y="price", height=100, width=200) - output = plot.postprocess(stocks).model_dump() # type: ignore - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - assert config["height"] == 100 - assert config["width"] == 200 - - def test_xlim_ylim(self): - plot = gr.LinePlot(x="date", y="price", x_lim=[200, 400], y_lim=[300, 500]) - output = plot.postprocess(stocks).model_dump() # type: ignore - config = json.loads(output["plot"]) - for layer in config["layer"]: - assert layer["encoding"]["x"]["scale"] == {"domain": [200, 400]} - assert layer["encoding"]["y"]["scale"] == {"domain": [300, 500]} - - def test_color_encoding(self): - plot = gr.LinePlot( - x="date", y="price", tooltip="symbol", color="symbol", overlay_point=True - ) - output = plot.postprocess(stocks).model_dump() # type: ignore - config = json.loads(output["plot"]) - for layer in config["layer"]: - assert layer["encoding"]["color"]["field"] == "symbol" - assert layer["encoding"]["color"]["scale"] == { - "domain": ["MSFT", "AMZN", "IBM", "GOOG", "AAPL"], - "range": [0, 1, 2, 3, 4], - } - assert layer["encoding"]["color"]["type"] == "nominal" - if layer["mark"]["type"] == "point": - assert layer["encoding"]["opacity"] == {} - - def test_lineplot_accepts_fn_as_value(self): - plot = gr.LinePlot( - value=lambda: stocks.sample(frac=0.1, replace=False), - x="date", - y="price", - color="symbol", - ) - assert isinstance(plot.value, dict) - assert isinstance(plot.value["plot"], str) diff --git a/test/components/test_native_plots.py b/test/components/test_native_plots.py new file mode 100644 index 0000000000..c099de209b --- /dev/null +++ b/test/components/test_native_plots.py @@ -0,0 +1,32 @@ +import gradio as gr + +from .plot_data import barley, simple + + +class TestNativePlot: + def test_plot_recognizes_correct_datatypes(self): + plot = gr.BarPlot( + value=simple, + x="date", + y="b", + ) + assert plot.value["datatypes"]["date"] == "temporal" + assert plot.value["datatypes"]["b"] == "quantitative" + + plot = gr.BarPlot( + value=simple, + x="a", + y="b", + color="c", + ) + assert plot.value["datatypes"]["a"] == "nominal" + assert plot.value["datatypes"]["b"] == "quantitative" + assert plot.value["datatypes"]["c"] == "quantitative" + + def test_plot_accepts_fn_as_value(self): + plot = gr.BarPlot( + value=lambda: barley.sample(frac=0.1, replace=False), + x="year", + y="yield", + ) + assert plot.value["mark"] == "bar" diff --git a/test/components/test_scatter_plot.py b/test/components/test_scatter_plot.py deleted file mode 100644 index 1369553d3b..0000000000 --- a/test/components/test_scatter_plot.py +++ /dev/null @@ -1,168 +0,0 @@ -import json -from unittest.mock import MagicMock, patch - -import gradio as gr - -from .plot_data import cars - - -class TestScatterPlot: - @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")}) - def test_get_config(self): - print(gr.ScatterPlot().get_config()) - assert gr.ScatterPlot().get_config() == { - "caption": None, - "elem_id": None, - "elem_classes": [], - "interactive": None, - "label": None, - "name": "scatterplot", - "bokeh_version": "3.0.3", - "show_actions_button": False, - "proxy_url": None, - "show_label": False, - "container": True, - "min_width": 160, - "scale": None, - "value": None, - "visible": True, - "x": None, - "y": None, - "color": None, - "size": None, - "shape": None, - "title": None, - "tooltip": None, - "x_title": None, - "y_title": None, - "color_legend_title": None, - "size_legend_title": None, - "shape_legend_title": None, - "color_legend_position": None, - "size_legend_position": None, - "shape_legend_position": None, - "height": None, - "width": None, - "x_lim": None, - "y_lim": None, - "x_label_angle": None, - "y_label_angle": None, - "_selectable": False, - "key": None, - } - - def test_no_color(self): - plot = gr.ScatterPlot( - x="Horsepower", - y="Miles_per_Gallon", - tooltip="Name", - title="Car Data", - x_title="Horse", - ) - output = plot.postprocess(cars).model_dump() # type: ignore - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - assert config["encoding"]["x"]["field"] == "Horsepower" - assert config["encoding"]["x"]["title"] == "Horse" - assert config["encoding"]["y"]["field"] == "Miles_per_Gallon" - assert config["title"] == "Car Data" - assert "height" not in config - assert "width" not in config - - def test_no_interactive(self): - plot = gr.ScatterPlot( - x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False - ) - output = plot.postprocess(cars).model_dump() # type: ignore - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - assert "selection" not in config - - def test_height_width(self): - plot = gr.ScatterPlot( - x="Horsepower", y="Miles_per_Gallon", height=100, width=200 - ) - output = plot.postprocess(cars).model_dump() # type: ignore - assert sorted(output.keys()) == ["chart", "plot", "type"] - config = json.loads(output["plot"]) - assert config["height"] == 100 - assert config["width"] == 200 - - def test_xlim_ylim(self): - plot = gr.ScatterPlot( - x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500] - ) - output = plot.postprocess(cars).model_dump() # type: ignore - config = json.loads(output["plot"]) - assert config["encoding"]["x"]["scale"] == {"domain": [200, 400]} - assert config["encoding"]["y"]["scale"] == {"domain": [300, 500]} - - def test_color_encoding(self): - plot = gr.ScatterPlot( - x="Horsepower", - y="Miles_per_Gallon", - tooltip="Name", - title="Car Data", - color="Origin", - ) - output = plot.postprocess(cars).model_dump() # type: ignore - config = json.loads(output["plot"]) - assert config["encoding"]["color"]["field"] == "Origin" - assert config["encoding"]["color"]["scale"] == { - "domain": ["USA", "Europe", "Japan"], - "range": [0, 1, 2], - } - assert config["encoding"]["color"]["type"] == "nominal" - - def test_two_encodings(self): - plot = gr.ScatterPlot( - show_label=False, - title="Two encodings", - x="Horsepower", - y="Miles_per_Gallon", - color="Acceleration", - shape="Origin", - ) - output = plot.postprocess(cars).model_dump() # type: ignore - config = json.loads(output["plot"]) - assert config["encoding"]["color"]["field"] == "Acceleration" - assert config["encoding"]["color"]["scale"] == { - "domain": [cars.Acceleration.min(), cars.Acceleration.max()], - "range": [0, 1], - } - assert config["encoding"]["color"]["type"] == "quantitative" - - assert config["encoding"]["shape"]["field"] == "Origin" - assert config["encoding"]["shape"]["type"] == "nominal" - - def test_legend_position(self): - plot = gr.ScatterPlot( - show_label=False, - title="Two encodings", - x="Horsepower", - y="Miles_per_Gallon", - color="Acceleration", - color_legend_position="none", - color_legend_title="Foo", - shape="Origin", - shape_legend_position="none", - shape_legend_title="Bar", - size="Acceleration", - size_legend_title="Accel", - size_legend_position="none", - ) - output = plot.postprocess(cars).model_dump() # type: ignore - config = json.loads(output["plot"]) - assert config["encoding"]["color"]["legend"] is None - assert config["encoding"]["shape"]["legend"] is None - assert config["encoding"]["size"]["legend"] is None - - def test_scatterplot_accepts_fn_as_value(self): - plot = gr.ScatterPlot( - value=lambda: cars.sample(frac=0.1, replace=False), - x="Horsepower", - y="Miles_per_Gallon", - color="Origin", - ) - assert isinstance(plot.value, dict) - assert isinstance(plot.value["plot"], str)