mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
basic text fixed
This commit is contained in:
parent
9913523170
commit
78cdc465d9
@ -1,5 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
def answer_question(quantity, animal, place, activity_list, morning, etc):
|
def answer_question(quantity, animal, place, activity_list, morning, etc):
|
||||||
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}""", "OK"
|
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}""", "OK"
|
||||||
@ -15,8 +15,8 @@ gr.Interface(answer_question,
|
|||||||
gr.inputs.Textbox(default="What else?"),
|
gr.inputs.Textbox(default="What else?"),
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
gr.outputs.Textbox(lines=8),
|
gr.outputs.Textbox(),
|
||||||
gr.outputs.Textbox(lines=1),
|
gr.outputs.Textbox(),
|
||||||
],
|
],
|
||||||
examples=[
|
examples=[
|
||||||
[2, "cat", "park", ["ran", "swam"], True],
|
[2, "cat", "park", ["ran", "swam"], True],
|
||||||
|
@ -5,7 +5,7 @@ from tensorflow.keras.layers import *
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
|
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
|
||||||
x_train, x_test = x_train.reshape(-1,784) / 255.0, x_test.reshape(-1,784) / 255.0
|
x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1,784) / 255.0
|
||||||
|
|
||||||
def get_trained_model(n):
|
def get_trained_model(n):
|
||||||
model = tf.keras.models.Sequential()
|
model = tf.keras.models.Sequential()
|
||||||
@ -23,6 +23,7 @@ def get_trained_model(n):
|
|||||||
print(model.evaluate(x_test, y_test))
|
print(model.evaluate(x_test, y_test))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if not os.path.exists("models/mnist.h5"):
|
if not os.path.exists("models/mnist.h5"):
|
||||||
model = get_trained_model(n=50000)
|
model = get_trained_model(n=50000)
|
||||||
model.save('models/mnist.h5')
|
model.save('models/mnist.h5')
|
||||||
@ -32,12 +33,14 @@ else:
|
|||||||
graph = tf.get_default_graph()
|
graph = tf.get_default_graph()
|
||||||
sess = tf.keras.backend.get_session()
|
sess = tf.keras.backend.get_session()
|
||||||
|
|
||||||
|
|
||||||
def recognize_digit(image):
|
def recognize_digit(image):
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
with sess.as_default():
|
with sess.as_default():
|
||||||
prediction = model.predict(image).tolist()[0]
|
prediction = model.predict(image).tolist()[0]
|
||||||
return {str(i): prediction[i] for i in range(10)}
|
return {str(i): prediction[i] for i in range(10)}
|
||||||
|
|
||||||
|
|
||||||
gr.Interface(
|
gr.Interface(
|
||||||
recognize_digit,
|
recognize_digit,
|
||||||
gradio.inputs.Sketchpad(flatten=True),
|
gradio.inputs.Sketchpad(flatten=True),
|
||||||
|
Loading…
Reference in New Issue
Block a user