basic text fixed

This commit is contained in:
Abubakar Abid 2020-07-08 16:53:24 -05:00
parent 9913523170
commit 78cdc465d9
2 changed files with 7 additions and 4 deletions

View File

@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from time import sleep
def answer_question(quantity, animal, place, activity_list, morning, etc): def answer_question(quantity, animal, place, activity_list, morning, etc):
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}""", "OK" return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}""", "OK"
@ -15,8 +15,8 @@ gr.Interface(answer_question,
gr.inputs.Textbox(default="What else?"), gr.inputs.Textbox(default="What else?"),
], ],
[ [
gr.outputs.Textbox(lines=8), gr.outputs.Textbox(),
gr.outputs.Textbox(lines=1), gr.outputs.Textbox(),
], ],
examples=[ examples=[
[2, "cat", "park", ["ran", "swam"], True], [2, "cat", "park", ["ran", "swam"], True],

View File

@ -5,7 +5,7 @@ from tensorflow.keras.layers import *
import gradio as gr import gradio as gr
(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data() (x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train.reshape(-1,784) / 255.0, x_test.reshape(-1,784) / 255.0 x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1,784) / 255.0
def get_trained_model(n): def get_trained_model(n):
model = tf.keras.models.Sequential() model = tf.keras.models.Sequential()
@ -23,6 +23,7 @@ def get_trained_model(n):
print(model.evaluate(x_test, y_test)) print(model.evaluate(x_test, y_test))
return model return model
if not os.path.exists("models/mnist.h5"): if not os.path.exists("models/mnist.h5"):
model = get_trained_model(n=50000) model = get_trained_model(n=50000)
model.save('models/mnist.h5') model.save('models/mnist.h5')
@ -32,12 +33,14 @@ else:
graph = tf.get_default_graph() graph = tf.get_default_graph()
sess = tf.keras.backend.get_session() sess = tf.keras.backend.get_session()
def recognize_digit(image): def recognize_digit(image):
with graph.as_default(): with graph.as_default():
with sess.as_default(): with sess.as_default():
prediction = model.predict(image).tolist()[0] prediction = model.predict(image).tolist()[0]
return {str(i): prediction[i] for i in range(10)} return {str(i): prediction[i] for i in range(10)}
gr.Interface( gr.Interface(
recognize_digit, recognize_digit,
gradio.inputs.Sketchpad(flatten=True), gradio.inputs.Sketchpad(flatten=True),