mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Plot Component (#805)
* plotly + matplotlib component * update plot demos and plotly component * fix gray bg * format * pnpm lock file * add bokeh * update plot demo * add bokeh support * ignore plot file * fixed demo * fixed sorting * update image-plot deprecation warning Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
b17afde9ff
commit
7552e1ef6c
@ -1,2 +1,3 @@
|
||||
numpy
|
||||
matplotlib
|
||||
matplotlib
|
||||
bokeh
|
||||
|
@ -1,43 +1,72 @@
|
||||
from math import sqrt
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import plotly.express as px
|
||||
import pandas as pd
|
||||
import bokeh.plotting as bk
|
||||
from bokeh.models import ColumnDataSource
|
||||
from bokeh.embed import json_item
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def outbreak(r, month, countries, social_distancing):
|
||||
def outbreak(plot_type, r, month, countries, social_distancing):
|
||||
months = ["January", "February", "March", "April", "May"]
|
||||
m = months.index(month)
|
||||
start_day = 30 * m
|
||||
final_day = 30 * (m + 1)
|
||||
x = np.arange(start_day, final_day + 1)
|
||||
day_count = x.shape[0]
|
||||
pop_count = {"USA": 350, "Canada": 40, "Mexico": 300, "UK": 120}
|
||||
r = sqrt(r)
|
||||
if social_distancing:
|
||||
r = sqrt(r)
|
||||
for i, country in enumerate(countries):
|
||||
series = x ** (r) * (i + 1)
|
||||
plt.plot(x, series)
|
||||
plt.title("Outbreak in " + month)
|
||||
plt.ylabel("Cases")
|
||||
plt.xlabel("Days since Day 0")
|
||||
plt.legend(countries)
|
||||
return plt
|
||||
df = pd.DataFrame({'day': x})
|
||||
for country in countries:
|
||||
df[country] = ( x ** (r) * (pop_count[country] + 1))
|
||||
|
||||
|
||||
if plot_type == "Matplotlib":
|
||||
fig = plt.figure()
|
||||
plt.plot(df['day'], df[countries])
|
||||
plt.title("Outbreak in " + month)
|
||||
plt.ylabel("Cases")
|
||||
plt.xlabel("Days since Day 0")
|
||||
plt.legend(countries)
|
||||
return fig
|
||||
elif plot_type == "Plotly":
|
||||
fig = px.line(df, x='day', y=countries)
|
||||
fig.update_layout(title="Outbreak in " + month,
|
||||
xaxis_title="Cases",
|
||||
yaxis_title="Days Since Day 0")
|
||||
return fig
|
||||
else:
|
||||
source = ColumnDataSource(df)
|
||||
p = bk.figure(title="Outbreak in " + month, x_axis_label="Cases", y_axis_label="Days Since Day 0")
|
||||
for country in countries:
|
||||
p.line(x='day', y=country, line_width=2, source=source)
|
||||
item_text = json_item(p, "plotDiv")
|
||||
return item_text
|
||||
|
||||
|
||||
|
||||
iface = gr.Interface(
|
||||
outbreak,
|
||||
[
|
||||
gr.inputs.Dropdown(
|
||||
["Matplotlib", "Plotly", "Bokeh"], label="Plot Type"
|
||||
),
|
||||
gr.inputs.Slider(1, 4, default=3.2, label="R"),
|
||||
gr.inputs.Dropdown(
|
||||
["January", "February", "March", "April", "May"], label="Month"
|
||||
),
|
||||
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
|
||||
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries",
|
||||
default=["USA", "Canada"]),
|
||||
gr.inputs.Checkbox(label="Social Distancing?"),
|
||||
],
|
||||
"plot",
|
||||
gr.outputs.Plot(type="auto"),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -15,9 +15,11 @@ from numbers import Number
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
from black import out
|
||||
from ffmpy import FFmpeg
|
||||
|
||||
from gradio import processing_utils
|
||||
@ -210,12 +212,12 @@ class Image(OutputComponent):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
|
||||
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
|
||||
plot (bool): DEPRECATED (Use the new 'plot' component). Whether to expect a plot to be returned by the function.
|
||||
label (str): component name in interface.
|
||||
"""
|
||||
if plot:
|
||||
warnings.warn(
|
||||
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
|
||||
"The 'plot' parameter has been deprecated. Use the new 'plot' component instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.type = "plot"
|
||||
@ -853,6 +855,70 @@ class State(OutputComponent):
|
||||
}
|
||||
|
||||
|
||||
class Plot(OutputComponent):
|
||||
"""
|
||||
Used for plot output.
|
||||
Output type: matplotlib plt or plotly figure
|
||||
Demos: outbreak_forecast
|
||||
"""
|
||||
|
||||
def __init__(self, type: str = None, label: Optional[str] = None):
|
||||
"""
|
||||
Parameters:
|
||||
type (str): type of plot (matplotlib, plotly)
|
||||
label (str): component name in interface.
|
||||
"""
|
||||
self.type = type
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {**super().get_template_context()}
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
"plot": {},
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Parameters:
|
||||
y (str): plot data
|
||||
Returns:
|
||||
(str): plot type
|
||||
(str): plot base64 or json
|
||||
"""
|
||||
dtype = self.type
|
||||
if self.type == "plotly":
|
||||
out_y = y.to_json()
|
||||
elif self.type == "matplotlib":
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
elif self.type == "bokeh":
|
||||
out_y = json.dumps(y)
|
||||
elif self.type == "auto":
|
||||
if isinstance(y, (ModuleType, matplotlib.pyplot.Figure)):
|
||||
dtype = "matplotlib"
|
||||
out_y = processing_utils.encode_plot_to_base64(y)
|
||||
elif isinstance(y, dict):
|
||||
dtype = "bokeh"
|
||||
out_y = json.dumps(y)
|
||||
else:
|
||||
dtype = "plotly"
|
||||
out_y = y.to_json()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type. Please choose from: 'plotly', 'matplotlib', 'bokeh'."
|
||||
)
|
||||
return {"type": dtype, "plot": out_y}
|
||||
|
||||
def deserialize(self, x):
|
||||
y = processing_utils.decode_base64_to_file(x).name
|
||||
return y
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
return self.save_flagged_file(dir, label, data, encryption_key)
|
||||
|
||||
|
||||
class Image3D(OutputComponent):
|
||||
"""
|
||||
Used for 3d image model output.
|
||||
|
@ -1,4 +1,5 @@
|
||||
packages/app/public/**
|
||||
pnpm-workspace.yaml
|
||||
packages/app/dist/**
|
||||
pnpm-lock.yaml
|
||||
pnpm-lock.yaml
|
||||
packages/app/src/components/output/Plot/Plot.svelte
|
@ -20,9 +20,10 @@
|
||||
"svelte": "^3.46.3",
|
||||
"svelte-check": "^2.4.1",
|
||||
"svelte-i18n": "^3.3.13",
|
||||
"vitest": "^0.3.2",
|
||||
"plotly.js-dist-min": "^2.10.1",
|
||||
"babylonjs": "^4.2.1",
|
||||
"babylonjs-loaders": "^4.2.1",
|
||||
"vitest": "^0.3.2"
|
||||
"babylonjs-loaders": "^4.2.1"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/three": "^0.138.0"
|
||||
|
@ -26,6 +26,7 @@ import OutputTextbox from "./output/Textbox/config.js";
|
||||
import OutputVideo from "./output/Video/config.js";
|
||||
import OutputTimeSeries from "./output/TimeSeries/config.js";
|
||||
import OutputChatbot from "./output/Chatbot/config.js";
|
||||
import OutputPlot from "./output/Plot/config.js";
|
||||
import OutputImage3D from "./output/Image3D/config.js";
|
||||
|
||||
import StaticButton from "./static/Button/config.js";
|
||||
@ -62,6 +63,7 @@ export const output_component_map = {
|
||||
timeseries: OutputTimeSeries,
|
||||
video: OutputVideo,
|
||||
chatbot: OutputChatbot,
|
||||
plot: OutputPlot,
|
||||
image3d: OutputImage3D
|
||||
};
|
||||
|
||||
|
73
ui/packages/app/src/components/output/Plot/Plot.svelte
Normal file
73
ui/packages/app/src/components/output/Plot/Plot.svelte
Normal file
@ -0,0 +1,73 @@
|
||||
<svelte:head>
|
||||
<!-- Loading Bokeh from CDN -->
|
||||
<script src="https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js" on:load={handleBokehLoaded} ></script>
|
||||
{#if bokehLoaded}
|
||||
<script src="https://cdn.pydata.org/bokeh/release/bokeh-widgets-2.4.2.min.js" on:load={() => initializeBokeh(1)} ></script>
|
||||
<script src="https://cdn.pydata.org/bokeh/release/bokeh-tables-2.4.2.min.js" on:load={() => initializeBokeh(2)}></script>
|
||||
<script src="https://cdn.pydata.org/bokeh/release/bokeh-gl-2.4.2.min.js" on:load={() => initializeBokeh(3)}></script>
|
||||
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(4)}></script>
|
||||
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(5)} ></script>
|
||||
{/if}
|
||||
</svelte:head>
|
||||
|
||||
<script lang="ts">
|
||||
export let value: string;
|
||||
export let theme: string;
|
||||
import { afterUpdate, onMount} from "svelte";
|
||||
import Plotly from "plotly.js-dist-min";
|
||||
|
||||
// Bokeh
|
||||
let bokehLoaded = false
|
||||
const resolves = []
|
||||
const bokehPromises = Array(6).fill(0).map((_, i) => createPromise(i))
|
||||
|
||||
const initializeBokeh = (index) => {
|
||||
if (value["type"] == "bokeh") {
|
||||
console.log(resolves)
|
||||
resolves[index]()
|
||||
}
|
||||
}
|
||||
|
||||
function createPromise(index) {
|
||||
return new Promise((resolve, reject) => {
|
||||
resolves[index] = resolve
|
||||
})
|
||||
}
|
||||
|
||||
function handleBokehLoaded() {
|
||||
initializeBokeh(0)
|
||||
bokehLoaded = true
|
||||
}
|
||||
|
||||
Promise.all(bokehPromises).then(() => {
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
|
||||
})
|
||||
|
||||
// Plotly
|
||||
afterUpdate(() => {
|
||||
if (value["type"] == "plotly") {
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
let plotDiv = document.getElementById("plotDiv");
|
||||
Plotly.newPlot(plotDiv, plotObj["data"], plotObj["layout"]);
|
||||
} else if (value["type"] == "bokeh") {
|
||||
let plotObj = JSON.parse(value["plot"]);
|
||||
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if value["type"] == "plotly" || value["type"] == "bokeh" }
|
||||
<div id="plotDiv" />
|
||||
{:else}
|
||||
<div
|
||||
class="output-image w-full h-80 flex justify-center items-center dark:bg-gray-600 relative"
|
||||
{theme}
|
||||
>
|
||||
<!-- svelte-ignore a11y-missing-attribute -->
|
||||
<img class="w-full h-full object-contain" src={value["plot"]} />
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style lang="postcss">
|
||||
</style>
|
5
ui/packages/app/src/components/output/Plot/config.ts
Normal file
5
ui/packages/app/src/components/output/Plot/config.ts
Normal file
@ -0,0 +1,5 @@
|
||||
import Component from "./Plot.svelte";
|
||||
|
||||
export default {
|
||||
component: Component
|
||||
};
|
6
ui/pnpm-lock.yaml
generated
6
ui/pnpm-lock.yaml
generated
@ -7,6 +7,7 @@ importers:
|
||||
'@types/three': ^0.138.0
|
||||
babylonjs: ^4.2.1
|
||||
babylonjs-loaders: ^4.2.1
|
||||
plotly.js-dist-min: ^2.10.1
|
||||
prettier: ^2.5.1
|
||||
prettier-plugin-svelte: ^2.6.0
|
||||
svelte: ^3.46.3
|
||||
@ -16,6 +17,7 @@ importers:
|
||||
dependencies:
|
||||
babylonjs: 4.2.1
|
||||
babylonjs-loaders: 4.2.1
|
||||
plotly.js-dist-min: 2.10.1
|
||||
prettier: 2.5.1
|
||||
prettier-plugin-svelte: 2.6.0_prettier@2.5.1+svelte@3.46.3
|
||||
svelte: 3.46.3
|
||||
@ -1766,6 +1768,10 @@ packages:
|
||||
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
|
||||
engines: {node: '>=8.6'}
|
||||
|
||||
/plotly.js-dist-min/2.10.1:
|
||||
resolution: {integrity: sha512-H0ls1C2uu2U+qWw76djo4/zOGtUKfMILwFhu7tCOaG/wH5ypujrYGCH03N9SQVf1SXcctTfW57USf8LmagSiPQ==}
|
||||
dev: false
|
||||
|
||||
/pn/1.1.0:
|
||||
resolution: {integrity: sha512-2qHaIQr2VLRFoxe2nASzsV6ef4yOOH+Fi9FBOVH6cqeSgUnoyySPZkxzLuzd+RYOQTRpROA0ztTMqxROKSb/nA==}
|
||||
dev: false
|
||||
|
Loading…
x
Reference in New Issue
Block a user