mirror of https://github.com/tencentmusic/cube-studio.git, synced 2024-12-15 06:09:57 +08:00
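"""MNIST training example with TensorFlow 2 and Horovod.

Each worker trains the same Keras CNN on MNIST; Horovod averages gradients
across workers via DistributedGradientTape and broadcasts the initial model
and optimizer state from rank 0.
"""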
import sys

import tensorflow as tf

import horovod
import horovod.tensorflow as hvd


def main():
    # Horovod: initialize Horovod.
    hvd.init()

    # Horovod: pin GPU to be used to process local rank (one GPU per process)
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

    # Each worker loads its own copy of MNIST; the rank-specific cache path
    # keeps workers from clobbering each other's download.
    (mnist_images, mnist_labels), _ = \
        tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
         tf.cast(mnist_labels, tf.int64))
    )
    dataset = dataset.repeat().shuffle(10000).batch(128)

    # Simple CNN classifier for 28x28 MNIST digits.
    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
        tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    loss = tf.losses.SparseCategoricalCrossentropy()

    # Horovod: adjust learning rate based on number of GPUs. Scaling by
    # hvd.size() compensates for the larger effective batch size.
    opt = tf.optimizers.Adam(0.001 * hvd.size())

    checkpoint_dir = './checkpoints'
    checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)

    @tf.function
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            probs = mnist_model(images, training=True)
            loss_value = loss(labels, probs)

        # Horovod: add Horovod Distributed GradientTape.
        tape = hvd.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, mnist_model.trainable_variables)
        opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        #
        # Note: broadcast should be done after the first gradient step to ensure optimizer
        # initialization.
        if first_batch:
            hvd.broadcast_variables(mnist_model.variables, root_rank=0)
            hvd.broadcast_variables(opt.variables(), root_rank=0)

        return loss_value

    # Horovod: adjust number of steps based on number of GPUs.
    for batch, (images, labels) in enumerate(dataset.take(10000 // hvd.size())):
        loss_value = training_step(images, labels, batch == 0)

        if batch % 10 == 0 and hvd.local_rank() == 0:
            print('Step #%d\tLoss: %.6f' % (batch, loss_value))

    # Horovod: save checkpoints only on worker 0 to prevent other workers from
    # corrupting them.
    if hvd.rank() == 0:
        checkpoint.save(checkpoint_dir)


if __name__ == '__main__':
    if len(sys.argv) == 4:
        # Launch training programmatically through horovod.run.
        # Expected arguments: <num_processes> <hosts> <gloo|mpi>
        np = int(sys.argv[1])
        hosts = sys.argv[2]
        comm = sys.argv[3]
        print('Running training through horovod.run')
        horovod.run(main, np=np, hosts=hosts, use_gloo=comm == 'gloo', use_mpi=comm == 'mpi')
    else:
        # Otherwise assume the script was started via horovodrun.
        main()
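
# Example launches (illustrative only; the script name is a placeholder and the
# process count / hosts string should match your environment):
#
#   # via the horovodrun CLI (the else-branch above):
#   horovodrun -np 4 python tensorflow2_mnist.py
#
#   # via horovod.run (the argv branch above): 4 processes on localhost, gloo backend:
#   python tensorflow2_mnist.py 4 localhost:4 gloo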