from keras.utils import multi_gpu_model
# 0 < gpus <= 你真实含有的gpu数量
parallel_model = multi_gpu_model(model, gpus=8)
# 进行编译
parallel_model.compile(loss='categorical_crossentropy',
optimizer='rmsprop')
# fit 方法会在8个gpu上分布运行,就是将数据一个batch里面的数据平均分到8个gpu上
# 由于batch_size = 256,每块gpu有256/8数据在运行
parallel_model.fit(x, y, epochs=20, batch_size=256)
#加载模型的时候,先将模型创建起来
model = create_model()
model = multi_gpu_model(model , gpus = 8)
tensorflow==1.8.0 -> tensorflow==1.7.0