Plot Component (#805)

* plotly + matplotlib component

* update plot demos and plotly component

* fix gray bg

* format

* pnpm lock file

* add bokeh

* update plot demo

* add bokeh support

* ignore plot file

* fixed demo

* fixed sorting

* update image-plot deprecation warning

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Dawood Khan 2022-04-14 17:45:31 -04:00 committed by GitHub
parent b17afde9ff
commit 7552e1ef6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 203 additions and 19 deletions

View File

@ -1,2 +1,3 @@
numpy
matplotlib
matplotlib
bokeh

View File

@ -1,43 +1,72 @@
from math import sqrt
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import pandas as pd
import bokeh.plotting as bk
from bokeh.models import ColumnDataSource
from bokeh.embed import json_item
import gradio as gr
def outbreak(r, month, countries, social_distancing):
def outbreak(plot_type, r, month, countries, social_distancing):
months = ["January", "February", "March", "April", "May"]
m = months.index(month)
start_day = 30 * m
final_day = 30 * (m + 1)
x = np.arange(start_day, final_day + 1)
day_count = x.shape[0]
pop_count = {"USA": 350, "Canada": 40, "Mexico": 300, "UK": 120}
r = sqrt(r)
if social_distancing:
r = sqrt(r)
for i, country in enumerate(countries):
series = x ** (r) * (i + 1)
plt.plot(x, series)
plt.title("Outbreak in " + month)
plt.ylabel("Cases")
plt.xlabel("Days since Day 0")
plt.legend(countries)
return plt
df = pd.DataFrame({'day': x})
for country in countries:
df[country] = ( x ** (r) * (pop_count[country] + 1))
if plot_type == "Matplotlib":
fig = plt.figure()
plt.plot(df['day'], df[countries])
plt.title("Outbreak in " + month)
plt.ylabel("Cases")
plt.xlabel("Days since Day 0")
plt.legend(countries)
return fig
elif plot_type == "Plotly":
fig = px.line(df, x='day', y=countries)
fig.update_layout(title="Outbreak in " + month,
xaxis_title="Cases",
yaxis_title="Days Since Day 0")
return fig
else:
source = ColumnDataSource(df)
p = bk.figure(title="Outbreak in " + month, x_axis_label="Cases", y_axis_label="Days Since Day 0")
for country in countries:
p.line(x='day', y=country, line_width=2, source=source)
item_text = json_item(p, "plotDiv")
return item_text
iface = gr.Interface(
outbreak,
[
gr.inputs.Dropdown(
["Matplotlib", "Plotly", "Bokeh"], label="Plot Type"
),
gr.inputs.Slider(1, 4, default=3.2, label="R"),
gr.inputs.Dropdown(
["January", "February", "March", "April", "May"], label="Month"
),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries"),
gr.inputs.CheckboxGroup(["USA", "Canada", "Mexico", "UK"], label="Countries",
default=["USA", "Canada"]),
gr.inputs.Checkbox(label="Social Distancing?"),
],
"plot",
gr.outputs.Plot(type="auto"),
)
if __name__ == "__main__":
iface.launch()

View File

@ -15,9 +15,11 @@ from numbers import Number
from types import ModuleType
from typing import TYPE_CHECKING, Dict, List, Optional
import matplotlib
import numpy as np
import pandas as pd
import PIL
from black import out
from ffmpy import FFmpeg
from gradio import processing_utils
@ -210,12 +212,12 @@ class Image(OutputComponent):
"""
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
plot (bool): DEPRECATED (Use the new 'plot' component). Whether to expect a plot to be returned by the function.
label (str): component name in interface.
"""
if plot:
warnings.warn(
"The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
"The 'plot' parameter has been deprecated. Use the new 'plot' component instead.",
DeprecationWarning,
)
self.type = "plot"
@ -853,6 +855,70 @@ class State(OutputComponent):
}
class Plot(OutputComponent):
"""
Used for plot output.
Output type: matplotlib plt or plotly figure
Demos: outbreak_forecast
"""
def __init__(self, type: str = None, label: Optional[str] = None):
"""
Parameters:
type (str): type of plot (matplotlib, plotly)
label (str): component name in interface.
"""
self.type = type
super().__init__(label)
def get_template_context(self):
return {**super().get_template_context()}
@classmethod
def get_shortcut_implementations(cls):
return {
"plot": {},
}
def postprocess(self, y):
"""
Parameters:
y (str): plot data
Returns:
(str): plot type
(str): plot base64 or json
"""
dtype = self.type
if self.type == "plotly":
out_y = y.to_json()
elif self.type == "matplotlib":
out_y = processing_utils.encode_plot_to_base64(y)
elif self.type == "bokeh":
out_y = json.dumps(y)
elif self.type == "auto":
if isinstance(y, (ModuleType, matplotlib.pyplot.Figure)):
dtype = "matplotlib"
out_y = processing_utils.encode_plot_to_base64(y)
elif isinstance(y, dict):
dtype = "bokeh"
out_y = json.dumps(y)
else:
dtype = "plotly"
out_y = y.to_json()
else:
raise ValueError(
"Unknown type. Please choose from: 'plotly', 'matplotlib', 'bokeh'."
)
return {"type": dtype, "plot": out_y}
def deserialize(self, x):
y = processing_utils.decode_base64_to_file(x).name
return y
def save_flagged(self, dir, label, data, encryption_key):
return self.save_flagged_file(dir, label, data, encryption_key)
class Image3D(OutputComponent):
"""
Used for 3d image model output.

View File

@ -1,4 +1,5 @@
packages/app/public/**
pnpm-workspace.yaml
packages/app/dist/**
pnpm-lock.yaml
pnpm-lock.yaml
packages/app/src/components/output/Plot/Plot.svelte

View File

@ -20,9 +20,10 @@
"svelte": "^3.46.3",
"svelte-check": "^2.4.1",
"svelte-i18n": "^3.3.13",
"vitest": "^0.3.2",
"plotly.js-dist-min": "^2.10.1",
"babylonjs": "^4.2.1",
"babylonjs-loaders": "^4.2.1",
"vitest": "^0.3.2"
"babylonjs-loaders": "^4.2.1"
},
"devDependencies": {
"@types/three": "^0.138.0"

View File

@ -26,6 +26,7 @@ import OutputTextbox from "./output/Textbox/config.js";
import OutputVideo from "./output/Video/config.js";
import OutputTimeSeries from "./output/TimeSeries/config.js";
import OutputChatbot from "./output/Chatbot/config.js";
import OutputPlot from "./output/Plot/config.js";
import OutputImage3D from "./output/Image3D/config.js";
import StaticButton from "./static/Button/config.js";
@ -62,6 +63,7 @@ export const output_component_map = {
timeseries: OutputTimeSeries,
video: OutputVideo,
chatbot: OutputChatbot,
plot: OutputPlot,
image3d: OutputImage3D
};

View File

@ -0,0 +1,73 @@
<svelte:head>
<!-- Loading Bokeh from CDN -->
<script src="https://cdn.bokeh.org/bokeh/release/bokeh-2.4.2.min.js" on:load={handleBokehLoaded} ></script>
{#if bokehLoaded}
<script src="https://cdn.pydata.org/bokeh/release/bokeh-widgets-2.4.2.min.js" on:load={() => initializeBokeh(1)} ></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-tables-2.4.2.min.js" on:load={() => initializeBokeh(2)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-gl-2.4.2.min.js" on:load={() => initializeBokeh(3)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(4)}></script>
<script src="https://cdn.pydata.org/bokeh/release/bokeh-api-2.4.2.min.js" on:load={() => initializeBokeh(5)} ></script>
{/if}
</svelte:head>
<script lang="ts">
export let value: string;
export let theme: string;
import { afterUpdate, onMount} from "svelte";
import Plotly from "plotly.js-dist-min";
// Bokeh
let bokehLoaded = false
const resolves = []
const bokehPromises = Array(6).fill(0).map((_, i) => createPromise(i))
const initializeBokeh = (index) => {
if (value["type"] == "bokeh") {
console.log(resolves)
resolves[index]()
}
}
function createPromise(index) {
return new Promise((resolve, reject) => {
resolves[index] = resolve
})
}
function handleBokehLoaded() {
initializeBokeh(0)
bokehLoaded = true
}
Promise.all(bokehPromises).then(() => {
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
})
// Plotly
afterUpdate(() => {
if (value["type"] == "plotly") {
let plotObj = JSON.parse(value["plot"]);
let plotDiv = document.getElementById("plotDiv");
Plotly.newPlot(plotDiv, plotObj["data"], plotObj["layout"]);
} else if (value["type"] == "bokeh") {
let plotObj = JSON.parse(value["plot"]);
window.Bokeh.embed.embed_item(plotObj, "plotDiv");
}
});
</script>
{#if value["type"] == "plotly" || value["type"] == "bokeh" }
<div id="plotDiv" />
{:else}
<div
class="output-image w-full h-80 flex justify-center items-center dark:bg-gray-600 relative"
{theme}
>
<!-- svelte-ignore a11y-missing-attribute -->
<img class="w-full h-full object-contain" src={value["plot"]} />
</div>
{/if}
<style lang="postcss">
</style>

View File

@ -0,0 +1,5 @@
import Component from "./Plot.svelte";
export default {
component: Component
};

6
ui/pnpm-lock.yaml generated
View File

@ -7,6 +7,7 @@ importers:
'@types/three': ^0.138.0
babylonjs: ^4.2.1
babylonjs-loaders: ^4.2.1
plotly.js-dist-min: ^2.10.1
prettier: ^2.5.1
prettier-plugin-svelte: ^2.6.0
svelte: ^3.46.3
@ -16,6 +17,7 @@ importers:
dependencies:
babylonjs: 4.2.1
babylonjs-loaders: 4.2.1
plotly.js-dist-min: 2.10.1
prettier: 2.5.1
prettier-plugin-svelte: 2.6.0_prettier@2.5.1+svelte@3.46.3
svelte: 3.46.3
@ -1766,6 +1768,10 @@ packages:
resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==}
engines: {node: '>=8.6'}
/plotly.js-dist-min/2.10.1:
resolution: {integrity: sha512-H0ls1C2uu2U+qWw76djo4/zOGtUKfMILwFhu7tCOaG/wH5ypujrYGCH03N9SQVf1SXcctTfW57USf8LmagSiPQ==}
dev: false
/pn/1.1.0:
resolution: {integrity: sha512-2qHaIQr2VLRFoxe2nASzsV6ef4yOOH+Fi9FBOVH6cqeSgUnoyySPZkxzLuzd+RYOQTRpROA0ztTMqxROKSb/nA==}
dev: false