这篇文章主要介绍ImageDataGenerator和flow()有什么用,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!

ImageDataGenerator的参数自己看文档

from keras.preprocessing import imageimport numpy as npX_train=np.ones((3,123,123,1))Y_train=np.array([[1],[2],[2]])generator=image.ImageDataGenerator(featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, samplewise_std_normalization=False, zca_whitening=False, zca_epsilon=1e-6, rotation_range=180, width_shift_range=0.2, height_shift_range=0.2, shear_range=0, zoom_range=0.001, channel_shift_range=0, fill_mode='nearest', cval=0., horizontal_flip=True, vertical_flip=True, rescale=None, preprocessing_function=None, data_format='channels_last')a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环'''batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机,输出的X,Y是一一对对应的如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配'''X,Y=next(a)print(Y)X,Y=next(a)print(Y)X,Y=next(a)print(Y)X,Y=next(a)

输出

[[2] [1] [2]][[2] [2] [1]][[2] [2] [1]][[2] [2] [1]]

补充知识:tensorflow 与keras 混用之坑

在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下

其中错误为:TypeError: tuple indices must be integers, not list

再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下

原训练代码

from tensorflow.python.keras.preprocessing.image import ImageDataGeneratorfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalizationfrom tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense #Каталог с данными для обученияtrain_dir = 'train'# Каталог с данными для проверкиval_dir = 'val'# Каталог с данными для тестированияtest_dir = 'val' # Размеры изображенияimg_width, img_height = 800, 800# Размерность тензора на основе изображения для входных данных в нейронную сеть# backend Tensorflow, channels_lastinput_shape = (img_width, img_height, 3)# Количество эпохepochs = 1# Размер мини-выборкиbatch_size = 4# Количество изображений для обученияnb_train_samples = 300# Количество изображений для проверкиnb_validation_samples = 25# Количество изображений для тестированияnb_test_samples = 25 model = Sequential() model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Conv2D(64, (5, 5), padding="same"))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(10, 10))) model.add(Flatten())model.add(Dense(512))model.add(Activation('relu'))model.add(Dropout(0.5))model.add(Dense(10, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer="Nadam", metrics=['accuracy'])print(model.summary())datagen = ImageDataGenerator(rescale=1. / 255) train_generator = datagen.flow_from_directory( train_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') val_generator = datagen.flow_from_directory( val_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') test_generator = datagen.flow_from_directory( test_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') model.fit_generator( train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs, validation_data=val_generator, validation_steps=nb_validation_samples // batch_size) print('Сохраняем сеть')model.save("grib.h6")print("Сохранение завершено!")

模型载入

from tensorflow.python.keras.preprocessing.image import ImageDataGeneratorfrom tensorflow.python.keras.models import Sequentialfrom tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalizationfrom tensorflow.python.keras.layers import Activation, Dropout, Flatten, Densefrom keras.models import load_model print("Загрузка сети")model = load_model("grib.h6")print("Загрузка завершена!")

报错

/usr/bin/python3.5 /home/disk2/py/neroset/do.py/home/mama/.local/lib/python3.5/site-packages/h6py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_convertersUsing TensorFlow backend.Загрузка сетиTraceback (most recent call last): File "/home/disk2/py/neroset/do.py", line 13, in <module> model = load_model("grib.h6") File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model model = model_from_config(model_config, custom_objects=custom_objects) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config return layer_module.deserialize(config, custom_objects=custom_objects) File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize printable_module_name='layer') File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object list(custom_objects.items()))) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config model.add(layer) File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add output_tensor = layer(self.outputs[0]) File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__ self.build(input_shapes[0]) File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build dim = input_shape[self.axis]TypeError: tuple indices must be integers or slices, not list Process finished with exit code 1

战斗种族解释

убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибкаВыяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)

强调文本 强调文本

keras.preprocessing.image import ImageDataGeneratorkeras.models import Sequentialkeras.layers import Conv2D, MaxPooling2D, BatchNormalizationkeras.layers import Activation, Dropout, Flatten, Dense

以上是ImageDataGenerator和flow()有什么用的所有内容,感谢各位的阅读!希望分享的内容对大家有帮助,更多相关知识,欢迎关注亿速云行业资讯频道!