Springboot 整合 Java DL4J 实现时尚穿搭推荐系统(实例代码)
作者:月下独码
Spring Boot 整合 Java Deeplearning4j 实现时尚穿搭推荐系统
一、引言
在当今时尚潮流不断变化的时代,人们对于个性化的穿搭需求越来越高。为了满足用户的这一需求,我们可以利用深度学习技术,通过分析用户上传的照片,为用户推荐适合的服装搭配。本文将介绍如何使用 Spring Boot 整合 Java Deeplearning4j 来实现一个时尚穿搭推荐系统。
二、技术概述
- Spring Boot:Spring Boot 是一个用于快速开发 Java 应用程序的框架。它简化了 Spring 应用程序的配置和部署,使得开发人员可以更加专注于业务逻辑的实现。
- Deeplearning4j:Deeplearning4j 是一个用于深度学习的 Java 库。它支持多种深度学习算法,包括卷积神经网络(CNN)、**循环神经网络(RNN)**等。在本案例中,我们将使用 Deeplearning4j 来实现图像识别功能。
- 神经网络选择:在本案例中,我们选择使用**卷积神经网络(CNN)**来实现图像识别功能。CNN 是一种专门用于处理图像数据的神经网络,它具有良好的图像识别能力。选择 CNN 的理由如下:
- 局部感知:
CNN可以自动学习图像中的局部特征,从而更好地识别图像中的物体。 - 权值共享:
CNN中的卷积层可以共享权值,从而减少了模型的参数数量,提高了模型的训练效率。 - 多层结构:
CNN通常由多个卷积层和池化层组成,这种多层结构可以提取图像中的不同层次的特征,从而提高了模型的识别准确率。
- 局部感知:
三、数据集格式
- 数据集来源:我们可以从时尚杂志、时尚博客等渠道收集时尚穿搭的图片作为我们的数据集。也可以使用公开的时尚穿搭数据集,如 Fashion-MNIST 数据集等。
- 数据集格式:我们将数据集存储为图像文件的形式,每个图像文件代表一个时尚穿搭的示例。图像文件的命名格式为“穿搭类型_肤色_身材.jpg”,例如“休闲装_白皙_苗条.jpg”。
- 数据集目录结构:我们将数据集存储在一个目录中,目录结构如下:
dataset/
|--休闲装/
| |--白皙/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--小麦色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--古铜色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
|--正装/
| |--白皙/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--小麦色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--古铜色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
|--运动装/
| |--白皙/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--小麦色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg
| |--古铜色/
| | |--苗条.jpg
| | |--中等.jpg
| | |--丰满.jpg四、Maven 依赖
Spring Boot 依赖:
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>这个依赖包含了 Spring Boot 的 Web 开发所需的组件,如 Spring MVC、Tomcat 等。
2. Deeplearning4j 依赖:
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>1.0.0-beta7</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>1.0.0-beta7</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui</artifactId>
<version>1.0.0-beta7</version>
</dependency>这些依赖包含了 Deeplearning4j 的核心库、神经网络库和用户界面库。
3. 其他依赖:
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
<version>3.1.0</version>
</dependency>
<dependency>
<groupId>commons-fileupload</groupId>
<artifactId>commons-fileupload</artifactId>
<version>1.3.3</version>
</dependency>这些依赖包含了 Servlet API 和文件上传组件,用于处理用户上传的照片。
五、代码示例
模型训练代码:
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GraphVertex;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
public class FashionRecommendationModel {
private ComputationGraph model;
public FashionRecommendationModel() {
// 定义神经网络结构
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.XAVIER)
.activation(Activation.RELU)
.convolutionMode(ConvolutionMode.Same)
.updater("adam")
.l2(0.0005);
// 输入层
int height = 224;
int width = 224;
int channels = 3;
builder.graphBuilder()
.addInputs("input")
.setInputTypes(InputType.convolutional(height, width, channels));
// 卷积层 1
builder.addLayer("conv1", new ConvolutionLayer.Builder(3, 3)
.nIn(channels)
.nOut(32)
.stride(1, 1)
.build(), "input");
// 池化层 1
builder.addLayer("pool1", new ConvolutionLayer.Builder(2, 2)
.stride(2, 2)
.build(), "conv1");
// 卷积层 2
builder.addLayer("conv2", new ConvolutionLayer.Builder(3, 3)
.nOut(64)
.stride(1, 1)
.build(), "pool1");
// 池化层 2
builder.addLayer("pool2", new ConvolutionLayer.Builder(2, 2)
.stride(2, 2)
.build(), "conv2");
// 卷积层 3
builder.addLayer("conv3", new ConvolutionLayer.Builder(3, 3)
.nOut(128)
.stride(1, 1)
.build(), "pool2");
// 池化层 3
builder.addLayer("pool3", new ConvolutionLayer.Builder(2, 2)
.stride(2, 2)
.build(), "conv3");
// 全连接层 1
int numClasses = 3; // 假设穿搭类型有 3 种
int numNodes = 1024;
builder.addLayer("fc1", new DenseLayer.Builder()
.nOut(numNodes)
.activation(Activation.RELU)
.build(), "pool3");
// 全连接层 2
builder.addLayer("fc2", new DenseLayer.Builder()
.nOut(numClasses)
.activation(Activation.SOFTMAX)
.build(), "fc1");
// 输出层
builder.addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(numClasses)
.activation(Activation.SOFTMAX)
.build(), "fc2");
// 构建计算图
model = new ComputationGraph(builder.build());
model.init();
}
public void trainModel(String datasetPath) {
// 加载数据集
List<INDArray> images = new ArrayList<>();
List<Integer> labels = new ArrayList<>();
File datasetDir = new File(datasetPath);
for (File categoryDir : datasetDir.listFiles()) {
int label = Integer.parseInt(categoryDir.getName());
for (File skinToneDir : categoryDir.listFiles()) {
for (File bodyShapeDir : skinToneDir.listFiles()) {
for (File imageFile : bodyShapeDir.listFiles()) {
NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
INDArray image = loader.asMatrix(imageFile);
images.add(image);
labels.add(label);
}
}
}
}
// 数据归一化
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
for (INDArray image : images) {
scaler.transform(image);
}
// 转换为 ND4J 的数据集格式
INDArray inputData = Nd4j.create(images.size(), 3, 224, 224);
INDArray labelData = Nd4j.create(images.size(), 3);
for (int i = 0; i < images.size(); i++) {
inputData.putRow(i, images.get(i));
labelData.putScalar(i, labels.get(i), 1.0);
}
// 训练模型
model.fit(inputData, labelData);
}
public int predict(INDArray image) {
// 数据归一化
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
scaler.transform(image);
// 预测
INDArray output = model.outputSingle(image);
int prediction = Nd4j.argMax(output, 1).getInt(0);
return prediction;
}
}这段代码定义了一个FashionRecommendationModel类,用于训练和预测时尚穿搭类型。在构造函数中,定义了一个卷积神经网络的结构,包括输入层、卷积层、池化层、全连接层和输出层。在trainModel方法中,加载数据集并进行数据归一化,然后将数据转换为 ND4J 的数据集格式,最后使用计算图进行训练。在predict方法中,对输入的图像进行数据归一化,然后使用训练好的模型进行预测,返回预测的穿搭类型。
2.Spring Boot 服务代码:
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
@SpringBootApplication
@RestController
public class FashionRecommendationApp {
private FashionRecommendationModel model;
public FashionRecommendationApp() {
model = new FashionRecommendationModel();
model.trainModel("dataset");
}
@PostMapping("/recommend")
public String recommend(@RequestParam("image") MultipartFile imageFile) throws IOException {
// 读取上传的图像文件
InputStream inputStream = new ByteArrayInputStream(imageFile.getBytes());
BufferedImage image = ImageIO.read(inputStream);
// 将图像转换为 ND4J 的数组格式
org.datavec.image.loader.NativeImageLoader loader = new org.datavec.image.loader.NativeImageLoader(224, 224, 3);
INDArray imageArray = loader.asMatrix(image);
// 使用模型进行预测
int prediction = model.predict(imageArray);
// 返回预测的穿搭建议
switch (prediction) {
case 0:
return "休闲装";
case 1:
return "正装";
case 2:
return "运动装";
default:
return "无法识别";
}
}
public static void main(String[] args) {
SpringApplication.run(FashionRecommendationApp.class, args);
}
}这段代码定义了一个 Spring Boot 应用程序,用于提供时尚穿搭推荐服务。在构造函数中,创建了一个FashionRecommendationModel对象,并使用数据集进行训练。在recommend方法中,处理用户上传的图像文件,将其转换为 ND4J 的数组格式,然后使用训练好的模型进行预测,最后返回预测的穿搭建议。
六、单元测试
模型训练测试:
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.jupiter.api.Assertions.assertEquals;
class FashionRecommendationModelTest {
private FashionRecommendationModel model;
@BeforeEach
void setUp() {
model = new FashionRecommendationModel();
}
@Test
void testTrainModel() {
model.trainModel("dataset");
ComputationGraph trainedModel = model.getModel();
assertNotNull(trainedModel);
}
@Test
void testPredict() {
model.trainModel("dataset");
// 加载测试图像
org.datavec.image.loader.NativeImageLoader loader = new org.datavec.image.loader.NativeImageLoader(224, 224, 3);
INDArray testImage = loader.asMatrix(new File("test_image.jpg"));
int prediction = model.predict(testImage);
// 根据测试图像的实际穿搭类型进行断言
assertEquals(0, prediction);
}
}这段代码对FashionRecommendationModel类进行了单元测试。在testTrainModel方法中,测试了模型的训练方法,确保训练后的模型不为空。在testPredict方法中,加载一个测试图像,使用训练好的模型进行预测,并根据测试图像的实际穿搭类型进行断言。
服务测试:
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.result.MockMvcResultMatchers;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.io.FileInputStream;
@SpringBootTest
class FashionRecommendationAppTest {
private MockMvc mockMvc;
@Test
void testRecommend() throws Exception {
FashionRecommendationApp app = new FashionRecommendationApp();
mockMvc = MockMvcBuilders.standaloneSetup(app).build();
// 加载测试图像
FileInputStream fis = new FileInputStream("test_image.jpg");
MockMultipartFile imageFile = new MockMultipartFile("image", "test_image.jpg", "image/jpeg", fis);
// 发送 POST 请求进行测试
mockMvc.perform(MockMvcRequestBuilders.multipart("/recommend")
.file(imageFile))
.andExpect(MockMvcResultMatchers.status().isOk())
.andExpect(MockMvcResultMatchers.content().string("休闲装"));
}
}这段代码对FashionRecommendationApp类进行了单元测试。在testRecommend方法中,使用MockMvc模拟发送 POST 请求,上传一个测试图像,并断言返回的穿搭建议是否正确。
七、预期输出
- 模型训练成功后,控制台会输出训练过程中的损失值等信息。
- 当用户上传一张照片时,服务会返回一个穿搭建议,如“休闲装”、“正装”或“运动装”。
八、参考资料
到此这篇关于Springboot 整合 Java DL4J 实现时尚穿搭推荐系统的文章就介绍到这了,更多相关Springboot Java DL4J 时尚穿搭推荐系统内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
