Use read_sql in database guide (#2604)

* Move demo to gradio + fix code

* Foo

* Address comments
This commit is contained in:
Freddy Boulton 2022-11-04 16:31:08 -04:00 committed by GitHub
parent 218fb9fa65
commit f795a4d3b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 44 deletions

View File

@ -7,7 +7,7 @@ No changes to highlight.
No changes to highlight.
## Documentation Changes:
No changes to highlight.
* Modified the "Connecting To a Database Guide" to use `pd.read_sql` as opposed to low-level postgres connector by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2604](https://github.com/gradio-app/gradio/pull/2604)
## Testing and Infrastructure Changes:
No changes to highlight.

View File

@ -0,0 +1,2 @@
psycopg2
matplotlib

View File

@ -0,0 +1,88 @@
import os
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
matplotlib.use("Agg")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
PORT = 8080
DB_NAME = "bikeshare"
connection_string = (
f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}"
)
def get_count_ride_type():
df = pd.read_sql(
"""
SELECT COUNT(ride_id) as n, rideable_type
FROM rides
GROUP BY rideable_type
ORDER BY n DESC
""",
con=connection_string,
)
fig_m, ax = plt.subplots()
ax.bar(x=df["rideable_type"], height=df["n"])
ax.set_title("Number of rides by bycycle type")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Bicycle Type")
return fig_m
def get_most_popular_stations():
df = pd.read_sql(
"""
SELECT COUNT(ride_id) as n, MAX(start_station_name) as station
FROM RIDES
WHERE start_station_name is NOT NULL
GROUP BY start_station_id
ORDER BY n DESC
LIMIT 5
""",
con=connection_string,
)
fig_m, ax = plt.subplots()
ax.bar(x=df["station"], height=df["n"])
ax.set_title("Most popular stations")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Station Name")
ax.set_xticklabels(df["station"], rotation=45, ha="right", rotation_mode="anchor")
ax.tick_params(axis="x", labelsize=8)
fig_m.tight_layout()
return fig_m
with gr.Blocks() as demo:
gr.Markdown(
"""
# Chicago Bike Share Dashboard
This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.
This demo uses psycopg2 but any postgresql client library (SQLAlchemy)
is compatible with gradio.
Connection credentials are handled by environment variables
defined as secrets in the Space.
If data were added to the database, the plots in this demo would update
whenever the webpage is reloaded.
This demo serves as a starting point for your database-connected apps!
"""
)
with gr.Row():
bike_type = gr.Plot()
station = gr.Plot()
demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
demo.load(get_most_popular_stations, inputs=None, outputs=station)
if __name__ == "__main__":
demo.launch()

View File

@ -1,6 +1,6 @@
# Connecting to a Database
Related spaces: https://huggingface.co/spaces/freddyaboulton/chicago-bike-share-dashboard
Related spaces: https://huggingface.co/spaces/gradio/chicago-bike-share-dashboard
Tags: TABULAR, PLOTS
## Introduction
@ -20,7 +20,7 @@ Our goal is to create a dashboard that will enable our business stakeholders to
At the end of this guide, we will have a functioning application that looks like this:
<gradio-app space="freddyaboulton/chicago-bike-share-dashboard"> </gradio-app>
<gradio-app space="gradio/chicago-bike-share-dashboard"> </gradio-app>
## Step 1 - Creating your database
@ -34,54 +34,45 @@ RDS will not let you create a postgreSQL instance on ports 80 or 443.
Once your database is created, download the dataset from Kaggle and upload it to your database.
For the sake of this demo, we will only upload March 2022 data.
## Step 2.a - Connect to your database!
We will be using the `psycopg2` postgreSQL driver for python.
This is not a hard requirement, so you can use your preferred driver, like SQLAlchemy.
The first step is to create a connection to the database. With `psycopg2` you can do so like this:
```python
connection = psycopg2.connect(user=os.environ["DB_USER"],
password=os.environ["DB_PASSWORD"],
host=os.environ["DB_HOST"],
port="8080",
database="bikeshare")
```
We will be passing the database username, password, and host as environment variables.
This will make our app more secure by avoiding storing sensitive information as plain text in our application files.
If you were to run our script locally, you could pass in your credentials as environment variables like so
```bash
DB_USER='username' DB_PASSWORD='password' DB_HOST='host' python app.py
```
## Step 2.b - Write your ETL code
## Step 2.a - Write your ETL code
We will be querying our database for the total count of rides split by the type of bicycle (electric, standard, or docked).
We will also query for the total count of rides that depart from each station and take the top 5.
We will then take the result of our queries and visualize them in with matplotlib.
We can do this with the following code:
We will use the pandas [read_sql](https://pandas.pydata.org/docs/reference/api/pandas.read_sql.html)
method to connect to the database. This requires the `psycopg2` library to be installed.
In order to connect to our database, we will specify the database username, password, and host as environment variables.
This will make our app more secure by avoiding storing sensitive information as plain text in our application files.
```python
import os
import pandas as pd
import matplotlib.pyplot as plt
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_HOST = os.getenv("DB_HOST")
PORT = 8080
DB_NAME = "bikeshare"
connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}"
def get_count_ride_type():
cursor = connection.cursor()
cursor.execute(
"""
df = pd.read_sql(
"""
SELECT COUNT(ride_id) as n, rideable_type
FROM rides
GROUP BY rideable_type
ORDER BY n DESC
"""
""",
con=connection_string
)
rides = cursor.fetchall()
cursor.close()
fig_m, ax = plt.subplots()
ax.bar(x=[s[1] for s in rides], height=[s[0] for s in rides])
ax.bar(x=df['rideable_type'], height=df['n'])
ax.set_title("Number of rides by bycycle type")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Bicycle Type")
@ -89,37 +80,46 @@ def get_count_ride_type():
def get_most_popular_stations():
cursor = connection.cursor()
cursor.execute(
df = pd.read_sql(
"""
SELECT COUNT(ride_id) as n, MAX(start_station_name)
SELECT COUNT(ride_id) as n, MAX(start_station_name) as station
FROM RIDES
WHERE start_station_name is NOT NULL
GROUP BY start_station_id
ORDER BY n DESC
LIMIT 5
"""
""",
con=connection_string
)
stations = cursor.fetchall()
fig_m, ax = plt.subplots()
ax.bar(x=[s[1] for s in stations], height=[s[0] for s in stations])
ax.bar(x=df['station'], height=df['n'])
ax.set_title("Most popular stations")
ax.set_ylabel("Number of Rides")
ax.set_xlabel("Station Name")
ax.set_xticklabels(
[s[1] for s in stations], rotation=45, ha="right", rotation_mode="anchor"
df['station'], rotation=45, ha="right", rotation_mode="anchor"
)
ax.tick_params(axis="x", labelsize=8)
fig_m.tight_layout()
return fig_m
```
If you were to run our script locally, you could pass in your credentials as environment variables like so
```bash
DB_USER='username' DB_PASSWORD='password' DB_HOST='host' python app.py
```
## Step 2.c - Write your gradio app
We will display or matplotlib plots in two separate `gr.Plot` components displayed side by side using `gr.Row()`.
Because we have wrapped our function to fetch the data in a `demo.load()` event trigger,
our demo will fetch the latest data **dynamically** from the database each time the web page loads. 🪄
```python
import gradio as gr
with gr.Blocks() as demo:
with gr.Row():
bike_type = gr.Plot()
@ -146,7 +146,7 @@ You will have to add the `DB_USER`, `DB_PASSWORD`, and `DB_HOST` variables as "R
## Conclusion
Congratulations! You know how to connect your gradio app to a database hosted on the cloud! ☁️
Our dashboard is now running on [Spaces](https://huggingface.co/spaces/freddyaboulton/chicago-bike-share-dashboard).
The complete code is [here](https://huggingface.co/spaces/freddyaboulton/chicago-bike-share-dashboard/blob/main/app.py)
Our dashboard is now running on [Spaces](https://huggingface.co/spaces/gradio/chicago-bike-share-dashboard).
The complete code is [here](https://huggingface.co/spaces/gradio/chicago-bike-share-dashboard/blob/main/app.py)
As you can see, gradio gives you the power to connect to your data wherever it lives and display however you want! 🔥