java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > Springboot  Java DL4J 时尚穿搭推荐系统

Springboot 整合 Java DL4J 实现时尚穿搭推荐系统(实例代码)

作者:月下独码

本文介绍了如何使用SpringBoot和JavaDeeplearning4j框架搭建一个时尚穿搭推荐系统,文章详细阐述了系统的技术架构、数据集格式、Maven依赖配置、模型训练和预测代码实现,以及单元测试和预期输出结果

Spring Boot 整合 Java Deeplearning4j 实现时尚穿搭推荐系统

一、引言

在当今时尚潮流不断变化的时代,人们对于个性化的穿搭需求越来越高。为了满足用户的这一需求,我们可以利用深度学习技术,通过分析用户上传的照片,为用户推荐适合的服装搭配。本文将介绍如何使用 Spring Boot 整合 Java Deeplearning4j 来实现一个时尚穿搭推荐系统。

二、技术概述

三、数据集格式

dataset/
    |--休闲装/
    |   |--白皙/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--小麦色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--古铜色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |--正装/
    |   |--白皙/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--小麦色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--古铜色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |--运动装/
    |   |--白皙/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--小麦色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg
    |   |--古铜色/
    |   |   |--苗条.jpg
    |   |   |--中等.jpg
    |   |   |--丰满.jpg

四、Maven 依赖

Spring Boot 依赖

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

这个依赖包含了 Spring Boot 的 Web 开发所需的组件,如 Spring MVCTomcat 等。
2. Deeplearning4j 依赖

<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-nn</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-ui</artifactId>
    <version>1.0.0-beta7</version>
</dependency>

这些依赖包含了 Deeplearning4j 的核心库、神经网络库和用户界面库。
3. 其他依赖

<dependency>
    <groupId>javax.servlet</groupId>
    <artifactId>javax.servlet-api</artifactId>
    <version>3.1.0</version>
</dependency>
<dependency>
    <groupId>commons-fileupload</groupId>
    <artifactId>commons-fileupload</artifactId>
    <version>1.3.3</version>
</dependency>

这些依赖包含了 Servlet API 和文件上传组件,用于处理用户上传的照片。

五、代码示例

模型训练代码

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GraphVertex;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
public class FashionRecommendationModel {
    private ComputationGraph model;
    public FashionRecommendationModel() {
        // 定义神经网络结构
        NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
               .weightInit(WeightInit.XAVIER)
               .activation(Activation.RELU)
               .convolutionMode(ConvolutionMode.Same)
               .updater("adam")
               .l2(0.0005);
        // 输入层
        int height = 224;
        int width = 224;
        int channels = 3;
        builder.graphBuilder()
               .addInputs("input")
               .setInputTypes(InputType.convolutional(height, width, channels));
        // 卷积层 1
        builder.addLayer("conv1", new ConvolutionLayer.Builder(3, 3)
               .nIn(channels)
               .nOut(32)
               .stride(1, 1)
               .build(), "input");
        // 池化层 1
        builder.addLayer("pool1", new ConvolutionLayer.Builder(2, 2)
               .stride(2, 2)
               .build(), "conv1");
        // 卷积层 2
        builder.addLayer("conv2", new ConvolutionLayer.Builder(3, 3)
               .nOut(64)
               .stride(1, 1)
               .build(), "pool1");
        // 池化层 2
        builder.addLayer("pool2", new ConvolutionLayer.Builder(2, 2)
               .stride(2, 2)
               .build(), "conv2");
        // 卷积层 3
        builder.addLayer("conv3", new ConvolutionLayer.Builder(3, 3)
               .nOut(128)
               .stride(1, 1)
               .build(), "pool2");
        // 池化层 3
        builder.addLayer("pool3", new ConvolutionLayer.Builder(2, 2)
               .stride(2, 2)
               .build(), "conv3");
        // 全连接层 1
        int numClasses = 3; // 假设穿搭类型有 3 种
        int numNodes = 1024;
        builder.addLayer("fc1", new DenseLayer.Builder()
               .nOut(numNodes)
               .activation(Activation.RELU)
               .build(), "pool3");
        // 全连接层 2
        builder.addLayer("fc2", new DenseLayer.Builder()
               .nOut(numClasses)
               .activation(Activation.SOFTMAX)
               .build(), "fc1");
        // 输出层
        builder.addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
               .nOut(numClasses)
               .activation(Activation.SOFTMAX)
               .build(), "fc2");
        // 构建计算图
        model = new ComputationGraph(builder.build());
        model.init();
    }
    public void trainModel(String datasetPath) {
        // 加载数据集
        List<INDArray> images = new ArrayList<>();
        List<Integer> labels = new ArrayList<>();
        File datasetDir = new File(datasetPath);
        for (File categoryDir : datasetDir.listFiles()) {
            int label = Integer.parseInt(categoryDir.getName());
            for (File skinToneDir : categoryDir.listFiles()) {
                for (File bodyShapeDir : skinToneDir.listFiles()) {
                    for (File imageFile : bodyShapeDir.listFiles()) {
                        NativeImageLoader loader = new NativeImageLoader(224, 224, 3);
                        INDArray image = loader.asMatrix(imageFile);
                        images.add(image);
                        labels.add(label);
                    }
                }
            }
        }
        // 数据归一化
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        for (INDArray image : images) {
            scaler.transform(image);
        }
        // 转换为 ND4J 的数据集格式
        INDArray inputData = Nd4j.create(images.size(), 3, 224, 224);
        INDArray labelData = Nd4j.create(images.size(), 3);
        for (int i = 0; i < images.size(); i++) {
            inputData.putRow(i, images.get(i));
            labelData.putScalar(i, labels.get(i), 1.0);
        }
        // 训练模型
        model.fit(inputData, labelData);
    }
    public int predict(INDArray image) {
        // 数据归一化
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        scaler.transform(image);
        // 预测
        INDArray output = model.outputSingle(image);
        int prediction = Nd4j.argMax(output, 1).getInt(0);
        return prediction;
    }
}

这段代码定义了一个FashionRecommendationModel类,用于训练和预测时尚穿搭类型。在构造函数中,定义了一个卷积神经网络的结构,包括输入层、卷积层、池化层、全连接层和输出层。在trainModel方法中,加载数据集并进行数据归一化,然后将数据转换为 ND4J 的数据集格式,最后使用计算图进行训练。在predict方法中,对输入的图像进行数据归一化,然后使用训练好的模型进行预测,返回预测的穿搭类型。

2.Spring Boot 服务代码

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
@SpringBootApplication
@RestController
public class FashionRecommendationApp {
    private FashionRecommendationModel model;
    public FashionRecommendationApp() {
        model = new FashionRecommendationModel();
        model.trainModel("dataset");
    }
    @PostMapping("/recommend")
    public String recommend(@RequestParam("image") MultipartFile imageFile) throws IOException {
        // 读取上传的图像文件
        InputStream inputStream = new ByteArrayInputStream(imageFile.getBytes());
        BufferedImage image = ImageIO.read(inputStream);
        // 将图像转换为 ND4J 的数组格式
        org.datavec.image.loader.NativeImageLoader loader = new org.datavec.image.loader.NativeImageLoader(224, 224, 3);
        INDArray imageArray = loader.asMatrix(image);
        // 使用模型进行预测
        int prediction = model.predict(imageArray);
        // 返回预测的穿搭建议
        switch (prediction) {
            case 0:
                return "休闲装";
            case 1:
                return "正装";
            case 2:
                return "运动装";
            default:
                return "无法识别";
        }
    }
    public static void main(String[] args) {
        SpringApplication.run(FashionRecommendationApp.class, args);
    }
}

这段代码定义了一个 Spring Boot 应用程序,用于提供时尚穿搭推荐服务。在构造函数中,创建了一个FashionRecommendationModel对象,并使用数据集进行训练。在recommend方法中,处理用户上传的图像文件,将其转换为 ND4J 的数组格式,然后使用训练好的模型进行预测,最后返回预测的穿搭建议。

六、单元测试

模型训练测试

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.junit.jupiter.api.Assertions.assertEquals;
class FashionRecommendationModelTest {
    private FashionRecommendationModel model;
    @BeforeEach
    void setUp() {
        model = new FashionRecommendationModel();
    }
    @Test
    void testTrainModel() {
        model.trainModel("dataset");
        ComputationGraph trainedModel = model.getModel();
        assertNotNull(trainedModel);
    }
    @Test
    void testPredict() {
        model.trainModel("dataset");
        // 加载测试图像
        org.datavec.image.loader.NativeImageLoader loader = new org.datavec.image.loader.NativeImageLoader(224, 224, 3);
        INDArray testImage = loader.asMatrix(new File("test_image.jpg"));
        int prediction = model.predict(testImage);
        // 根据测试图像的实际穿搭类型进行断言
        assertEquals(0, prediction);
    }
}

这段代码对FashionRecommendationModel类进行了单元测试。在testTrainModel方法中,测试了模型的训练方法,确保训练后的模型不为空。在testPredict方法中,加载一个测试图像,使用训练好的模型进行预测,并根据测试图像的实际穿搭类型进行断言。

服务测试

import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.result.MockMvcResultMatchers;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import java.io.FileInputStream;
@SpringBootTest
class FashionRecommendationAppTest {
    private MockMvc mockMvc;
    @Test
    void testRecommend() throws Exception {
        FashionRecommendationApp app = new FashionRecommendationApp();
        mockMvc = MockMvcBuilders.standaloneSetup(app).build();
        // 加载测试图像
        FileInputStream fis = new FileInputStream("test_image.jpg");
        MockMultipartFile imageFile = new MockMultipartFile("image", "test_image.jpg", "image/jpeg", fis);
        // 发送 POST 请求进行测试
        mockMvc.perform(MockMvcRequestBuilders.multipart("/recommend")
                       .file(imageFile))
               .andExpect(MockMvcResultMatchers.status().isOk())
               .andExpect(MockMvcResultMatchers.content().string("休闲装"));
    }
}

这段代码对FashionRecommendationApp类进行了单元测试。在testRecommend方法中,使用MockMvc模拟发送 POST 请求,上传一个测试图像,并断言返回的穿搭建议是否正确。

七、预期输出

八、参考资料

Deeplearning4j 官方文档

Spring Boot 官方文档

积神经网络介绍

到此这篇关于Springboot 整合 Java DL4J 实现时尚穿搭推荐系统的文章就介绍到这了,更多相关Springboot Java DL4J 时尚穿搭推荐系统内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文