整合DL4J训练模型与Web工程


声明:本文转载自https://my.oschina.net/u/1778239/blog/1648854,转载目的在于传递更多信息,仅供学习交流之用。如有侵权行为,请联系我,我会及时删除。

一、前言

    上一篇博客《有趣的卷积神经网络》介绍如何基于deeplearning4j对手写数字识别进行训练,对于整个训练集只训练了一次,正确率是0.9897,随着迭代次数的增加,网络模型将更加逼近训练集,下面是对训练集迭代十次的评估结果,总之迭代次数的增加会更加逼近模型(注:增加迭代次数有时也会发生过拟合,有时候也并非很奏效,具体情况具体分析)。

 Accuracy:        0.9919
 Precision:       0.9919
 Recall:          0.9918
 F1 Score:        0.9918

二、导读

    1、web环境搭建

    2、基于canvas构建前端画图界面

    3、整合dl4j训练模型

三、web环境搭建

    1、eclipse  new一个Maven project ,填好maven坐标,packaging选war

<groupId>org.dl4j</groupId> <artifactId>digitalrecognition</artifactId> <version>0.0.1-SNAPSHOT</version> <packaging>war</packaging>

    2、配置Jar包依赖,由于servlet-api一般由web容器提供,所以scope为provided,这样不会被打入war包里。

<dependencies> 		<dependency> 			<groupId>org.springframework</groupId> 			<artifactId>spring-webmvc</artifactId> 			<version>4.3.4.RELEASE</version> 		</dependency> 		<dependency> 			<groupId>javax.servlet</groupId> 			<artifactId>servlet-api</artifactId> 			<version>2.5</version> 			<scope>provided</scope> 		</dependency> 		<dependency> 			<groupId>com.fasterxml.jackson.core</groupId> 			<artifactId>jackson-core</artifactId> 			<version>2.5.3</version> 		</dependency>  		<dependency> 			<groupId>com.fasterxml.jackson.core</groupId> 			<artifactId>jackson-annotations</artifactId> 			<version>2.5.3</version> 		</dependency>  		<dependency> 			<groupId>com.fasterxml.jackson.core</groupId> 			<artifactId>jackson-databind</artifactId> 			<version>2.5.3</version> 		</dependency> 		<dependency> 			<groupId>commons-fileupload</groupId> 			<artifactId>commons-fileupload</artifactId> 			<version>1.3.1</version> 		</dependency> 		<dependency> 			<groupId>org.deeplearning4j</groupId> 			<artifactId>deeplearning4j-core</artifactId> 			<version>0.9.1</version> 		</dependency> 		<dependency> 			<groupId>org.nd4j</groupId> 			<artifactId>nd4j-native-platform</artifactId> 			<version>0.9.1</version> 		</dependency> 	</dependencies>

    3、为了开发方便,不用把web工程部署到外置web容器,所以在开发时用mavan tomcat插件是比较方便的。运行时mvn tomcat7:run即可

<build> 		<plugins> 			<plugin> 				<groupId>org.apache.tomcat.maven</groupId> 				<artifactId>tomcat7-maven-plugin</artifactId> 				<version>2.2</version> 				<configuration> 					<uriEncoding>UTF-8</uriEncoding> 					<path>/</path> 					<port>8080</port> 					<protocol>org.apache.coyote.http11.Http11NioProtocol</protocol> 					<maxThreads>1000</maxThreads> 					<minSpareThreads>100</minSpareThreads> 				</configuration> 			</plugin> 		</plugins> 	</build>

    4、web常规配置web.xml,filter、servlet、listener这里就略去了。

四、前端canvas画图实现

    1、html元素、css

<style type="text/css"> body { 	padding: 0; 	margin: 0; 	background: white; }  #canvas { 	margin: 100px 0 0 300px; }  #canvas>span { 	color: white; 	font-size: 14px; }  #result { 	margin: 0px 0 0 300px; } </style> <html> <head> <title>数字识别</title> </head> <body> 	<canvas id="canvas" width="280" height="280"></canvas> 	<button onclick="predict()">预测</button> 	<div id="result"> 		识别结果:<font size="18" id="digit"></font> 	</div> </body> </html>

    2、js代码实现在canvas画布连线操作,并将图片转化为base64格式,ajax发送给后端,这里画布的大小是280px,所以图片到了后端,需要缩小至十分之一。

<script src="/js/jquery-3.2.1.min.js"></script> <script type="text/javascript"> 	/*获取绘制环境*/ 	var canvas = $('#canvas')[0].getContext('2d'); 	canvas.strokeStyle = "white";//线条的颜色 	canvas.lineWidth = 10;//线条粗细 	canvas.fillStyle = 'black' 	canvas.fillRect(0, 0, 280, 280); 	$('#canvas').on('mousedown', function() { 		/*开始绘制*/ 		canvas.beginPath(); 		/*设置动画绘制起点坐标*/ 		canvas.moveTo(event.pageX - 300, event.pageY - 100); 		$('#canvas').on('mousemove', function() { 			/*设置下一个点坐标*/ 			canvas.lineTo(event.pageX - 300, event.pageY - 100); 			/*画线*/ 			canvas.stroke(); 		}); 	}).on('mouseup', function() { 		$('#canvas').off('mousemove'); 	}); 	function predict() { 		var img = $('#canvas')[0].toDataURL("image/png"); 		$.ajax({ 			url : "/digitalRecognition/predict", 			type : "post", 			data : { 				"img" : img.substring(img.indexOf(",") + 1) 			}, 			success : function(response) { 				$("#digit").html(response); 			}, 			error : function() { 			} 		}); 	} </script>

    整体呈现的界面如下,可以画图。

五、后端java代码

@RequestMapping("/digitalRecognition") @Controller public class DigitalRecognitionController implements InitializingBean { 	private MultiLayerNetwork net;  	@ResponseBody 	@RequestMapping("/predict") 	public int predict(@RequestParam(value = "img") String img) throws Exception { 		String imagePath= generateImage(img);//将base64图片转化为png图片 		imagePath= zoomImage(imagePath);//将图片缩小至28*28 		DataNormalization scaler = new ImagePreProcessingScaler(0, 1); 		ImageRecordReader testRR = new ImageRecordReader(28, 28, 1); 		File testData = new File(imagePath); 		FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS); 		testRR.initialize(testSplit); 		DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, 1); 		testIter.setPreProcessor(scaler); 		INDArray array = testIter.next().getFeatureMatrix(); 		return net.predict(array)[0]; 	}  	private String generateImage(String img) { 		BASE64Decoder decoder = new BASE64Decoder(); 		String filePath = WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png"; 		try { 			byte[] b = decoder.decodeBuffer(img); 			for (int i = 0; i < b.length; ++i) { 				if (b[i] < 0) { 					b[i] += 256; 				} 			} 			OutputStream out = new FileOutputStream(filePath); 			out.write(b); 			out.flush(); 			out.close(); 		} catch (Exception e) { 			e.printStackTrace(); 		} 		return filePath; 	} 	 	private String zoomImage(String filePath){ 		String imagePath=WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png"; 		try { 			BufferedImage bufferedImage = ImageIO.read(new File(filePath)); 			Image image = bufferedImage.getScaledInstance(28, 28, Image.SCALE_SMOOTH); 			BufferedImage tag = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB); 			Graphics g = tag.getGraphics(); 			g.drawImage(image, 0, 0, null); // 绘制处理后的图 			g.dispose(); 			ImageIO.write(tag, "png",new File(imagePath)); 		} catch (Exception e) { 			e.printStackTrace(); 		} 		return imagePath; 	} 	  	@Override 	public void afterPropertiesSet() throws Exception { 		net = ModelSerializer.restoreMultiLayerNetwork(new File(WebConstant.WEB_ROOT + "model/minist-model.zip")); 	}  }

    代码说明:

    1、InitializingBean是spring bean生命周期中的一个环节,spring构建bean的过程中会执行afterPropertiesSet方法,这里用这个方法来加载已经定型的网络。

      2、generateImage是用来将前端传过来的base64串转化为png格式。

      3、zoomImage方法将前端的280*280缩小至28*28和训练数据一致,并存到webroot的upload目录下。

     4、predict进行预测,将转化好的28*28的图片读取出来,张量化,把像素点的值压缩至0到1,预测,最后结果是一个数组,由于只有一张图片,取数组的第一个元素即可。

六、测试,mvn tomcat7:run,浏览器访问http://localhost:8080即可玩手写数字识别了

    

           

    测试结果马马虎虎,大体上实现了基本功能。

    git地址:https://gitee.com/lxkm/dl4j-demo/tree/master/digitalrecognition

    快乐源于分享。

 

 

 

 

 

 

本文发表于2018年03月22日 22:38
(c)注:本文转载自https://my.oschina.net/u/1778239/blog/1648854,转载目的在于传递更多信息,并不代表本网赞同其观点和对其真实性负责。如有侵权行为,请联系我们,我们会及时删除.

阅读 3417 讨论 0 喜欢 0

抢先体验

扫码体验
趣味小程序
文字表情生成器

闪念胶囊

你要过得好哇,这样我才能恨你啊,你要是过得不好,我都不知道该恨你还是拥抱你啊。

直抵黄龙府,与诸君痛饮尔。

那时陪伴我的人啊,你们如今在何方。

不出意外的话,我们再也不会见了,祝你前程似锦。

这世界真好,吃野东西也要留出这条命来看看

快捷链接
网站地图
提交友链
Copyright © 2016 - 2021 Cion.
All Rights Reserved.
京ICP备2021004668号-1