keras建模的3种方式详解
作者:月疯
这篇文章主要介绍了keras建模的3种方式详解,keras是Google公司于2016年发布的以tensorflow为后端的用于深度学习网络训练的高阶API,因接口设计非常人性化,深受程序员的喜爱,需要的朋友可以参考下
keras建模的3种方式
keras是google公司2016年发布的tensorflow为后端的深度学习网络的高级接口。
三种建模方式:
- 序列模型
- 函数模型
- 子类模型
第一种序列模型:
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Sequential from keras.models import load_model from keras.layers import Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #序列模型 def DNN(train_x,train_y,valid_x,valid_y): #創建模型 model=Sequential() model.add(Dense(64,input_dim=784,activation='relu')) model.add(Dense(128,activation='relu')) model.add(Dense(10,activation='softmax')) #查看网络模型 model.summary() #编译模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('sequential.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN(train_x,train_y,valid_x,valid_y) model=load_model('sequential.h5') #下载模型 pre=model.predict(test_x) #测试验证 #计算验证集精度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第二种函数模型
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.models import load_model from keras.layers import Input,Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #函数模型 def DNN(train_x,train_y,valid_x,valid_y): #创建模型 inputs=Input(shape=(784,)) x=Dense(64,activation='relu')(inputs) x=Dense(128,activation='relu')(x) output=Dense(10,activation='softmax')(x) model=Model(input=inputs,output=output) #查看网络结构 model.summary() #编译模型 model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #保存模型 model.save('fun_model.h5') train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') DNN(train_x,train_y,valid_x,valid_y) model=load_model('fun_model.h5') #下载模型 pre=model.predict(test_x) #验证数据集 #验证数据集准确度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
第三种子类模型
import numpy as np from tensorflow.examples.tutorials.mnist import input_data from keras.models import Model from keras.layers import Dense #加载数据 def read_data(path): mnist=input_data.read_data_sets(path,one_hot=True) train_x,train_y=mnist.train.images,mnist.train.labels, valid_x,valid_y=mnist.validation.images,mnist.validation.labels, test_x,test_y=mnist.test.images,mnist.test.labels return train_x,train_y,valid_x,valid_y,test_x,test_y #子类模型 class DNN(Model): def __init__(self,train_x,train_y,valid_x,valid_y): super(DNN,self).__init__() #初始化网络模型 self.dense1=Dense(64,input_dim=784,activation='relu') self.dense2=Dense(128,activation='relu') self.dense3=Dense(10,activation='softmax') def call(self,inputs): #回调順序 x=self.dense1(inputs) x=self.dense2(x) x=self.dense3(x) return x train_x,train_y,valid_x,valid_y,test_x,test_y=read_data('MNIST_data') model=DNN(train_x,train_y,valid_x,valid_y) #编译模型(学习率、损失函数、模型评估) model.compile(optimizer='adam(lr=0.001)',loss='categorical_crossentropy',metrics=['accuracy']) #训练模型 model.fit(train_x,train_y,batch_size=500,nb_epoch=100,verbose=1,validation_data=(valid_x,valid_y)) #查看网络结构 model.summary() pre=model.predict(test_x) #验证数据集 #计算验证数据集的准确度 a=np.argmax(pre,1) b=np.argmax(test_y,1) t=(a==b).astype(int) acc=np.sum(t)/len(a) print(acc)
常用的损失函数:
mse #均方差(回归)
mae #绝对误差(回归)
binary_crossentropy #二值交叉熵(二分类,逻辑回归)
categorical_crossentropy #交叉熵(多分类)
到此这篇关于keras建模的3种方式详解的文章就介绍到这了,更多相关keras建模方式内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!