mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
5.4 KiB
5.4 KiB
In [1]:
%load_ext autoreload %autoreload 2 import tensorflow as tf import gradio
In [2]:
# Data loading and hyperparameters.
n_classes = 10
SEED = 42
# Graph-level seed: it must be set before any ops (e.g. the tf.random_normal weight
# initializers in the model cell) are created, otherwise weight init differs on every
# Restart & Run All.
tf.set_random_seed(SEED)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten each image to a 784-long vector and scale pixel values by 1/255.
x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1, 784) / 255.0
# One-hot encode the integer labels into n_classes columns.
y_train = tf.keras.utils.to_categorical(y_train, n_classes).astype(float)
y_test = tf.keras.utils.to_categorical(y_test, n_classes).astype(float)

# Sanity checks on the prepared arrays.
assert x_train.shape[1] == 784 and x_test.shape[1] == 784
assert y_train.shape[1] == n_classes and y_test.shape[1] == n_classes
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

learning_rate = 0.5
epochs = 5
batch_size = 100
In [3]:
# Model: one ReLU hidden layer followed by a linear output layer producing logits.
# Layer widths are derived from the prepared data instead of being hard-coded.
n_inputs = x_train.shape[1]  # 784 after the flattening in the data cell
n_hidden = 300

x = tf.placeholder(tf.float32, [None, n_inputs], name="x")
y = tf.placeholder(tf.float32, [None, n_classes], name="y")

W1 = tf.Variable(tf.random_normal([n_inputs, n_hidden], stddev=0.03), name='W1')
b1 = tf.Variable(tf.random_normal([n_hidden]), name='b1')
W2 = tf.Variable(tf.random_normal([n_hidden, n_classes], stddev=0.03), name='W2')
# Output-layer bias, mirroring b1. Zero-initialised, so the untrained network starts out
# identical to a bias-free output layer and training learns the bias from there.
b2 = tf.Variable(tf.zeros([n_classes]), name='b2')

hidden_out = tf.nn.relu(tf.add(tf.matmul(x, W1), b1))
# Raw (unnormalised) logits: no softmax is applied here.
y_ = tf.add(tf.matmul(hidden_out, W2), b2)
WARNING:tensorflow:From C:\Users\ALI\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer.
In [4]:
# Class probabilities from the logits (this is what `predict` returns).
probs = tf.nn.softmax(y_)

# Batch-mean softmax cross-entropy; the _v2 op consumes raw logits directly.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=y_)
)

# Plain SGD. Note: `optimizer` holds the training op returned by minimize(),
# not the optimizer object itself.
optimizer = (
    tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    .minimize(cross_entropy)
)
In [5]:
# Evaluation ops: fraction of rows whose argmax over the logits matches the argmax
# of the one-hot label.
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Built after the model cell so it initialises every variable defined there.
init_op = tf.global_variables_initializer()
In [6]:
# Train with mini-batch SGD. The session is intentionally left open: `predict` reuses it.
sess = tf.Session()
sess.run(init_op)

total_batch = int(len(y_train) / batch_size)
for epoch in range(epochs):
    avg_cost = 0
    # Consecutive [start, end) windows; a trailing partial batch (if any) is skipped,
    # which matches the total_batch divisor above.
    for start, end in zip(range(0, len(y_train), batch_size),
                          range(batch_size, len(y_train) + 1, batch_size)):
        batch_x = x_train[start:end]
        batch_y = y_train[start:end]
        # Index the result instead of unpacking into `_`, which would shadow IPython's
        # last-output variable.
        batch_cost = sess.run([optimizer, cross_entropy],
                              feed_dict={x: batch_x, y: batch_y})[1]
        avg_cost += batch_cost / total_batch
    # Previously computed but never reported.
    print("Epoch {}: avg cost = {:.4f}".format(epoch + 1, avg_cost))

# The accuracy op was built but never evaluated; report held-out accuracy.
test_accuracy = sess.run(accuracy, feed_dict={x: x_test, y: y_test})
print("Test accuracy: {:.4f}".format(test_accuracy))
In [7]:
def predict(inp):
    """Run the trained network on ``inp`` and return its softmax class probabilities.

    ``inp`` is fed to the ``x`` placeholder, whose shape is [None, 784], so it must be
    a batch of flattened 784-value rows (NOTE(review): confirm the Gradio Sketchpad
    output matches this shape and the [0, 1] scaling used in training).
    """
    feed = {x: inp}
    return sess.run(probs, feed_dict=feed)
In [8]:
# Drawing-canvas input; flatten=True requests a flattened pixel array so it can be fed
# to the 784-wide `x` placeholder (NOTE(review): confirm shape/scale in this Gradio version).
inp = gradio.inputs.Sketchpad(flatten=True)

# Expose `predict` as a plain Python function ("pyfunc") with a label-style output.
io = gradio.Interface(
    model=predict,
    model_type="pyfunc",
    inputs=inp,
    outputs="label",
)
In [9]:
# Start the interface's local web server. share=True also requests a temporary public
# URL, so anyone holding that link can query the model; drop it for local-only use.
io.launch(share=True)
No validation samples for this interface... skipping validation. NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu Model is running locally at: http://localhost:7860/interface.html Model available publicly for 8 hours at: https://share.gradio.app/25d5d472
Out[9]:
(<gradio.networking.serve_files_in_background.<locals>.HTTPServer at 0x229f97553c8>, 'http://localhost:7860/', 'http://25d5d472.ngrok.io')
In [ ]: