PostgreSQL

关注公众号 jb51net

关闭
首页 > 数据库 > PostgreSQL > PostgreSQL COPY批量数据写入

PostgreSQL使用COPY协议高效批量数据写入的实战指南

作者:ezreal_pan

这篇文章主要介绍了PostgreSQL的COPY协议,这是一种高效批量数据导入导出的二进制协议,适用于需要高效写入大量数据的场景,COPY协议通过流式处理、事务安全和无参数限制等优势,显著提升了数据写入性能,并结合事务管理保证了数据一致性,需要的朋友可以参考下

问题背景

在开发过程中,我们经常会遇到需要批量写入大量数据到 PostgreSQL 数据库的场景。当使用传统的参数化插入语句时,可能会遇到如下错误:

pq: got 86575 parameters but PostgreSQL only supports 65535 parameters

这是因为 PostgreSQL 对单个查询的参数数量有限制(通常为 65535)。传统的解决方案是进行数据分片,分批写入数据库。但这种方法存在以下问题:

COPY 协议解决方案

COPY 协议简介

PostgreSQL 的 COPY 协议是专门为高效批量数据操作设计的二进制协议,具有以下优势:

  1. 高性能:避免了 SQL 解析开销,直接使用二进制格式传输数据
  2. 低内存占用:流式处理,不需要在内存中构建庞大的 SQL 语句
  3. 事务安全:可以在事务中执行,保证数据一致性
  4. 无参数限制:不受 PostgreSQL 参数数量限制

二进制协议原理

COPY 协议使用 PostgreSQL 的前后端协议进行数据传输,其工作流程如下:

  1. 启动 COPY 模式:客户端发送 COPY FROM STDIN 命令
  2. 数据传输:使用二进制格式按行发送数据
  3. 结束传输:发送特定的结束标记
  4. 确认完成:服务器返回处理结果

二进制格式避免了文本解析的开销,直接使用网络字节序传输数据,大大提高了传输效率。

实战实现

依赖库

import (
    "github.com/lib/pq"
    "gorm.io/gorm"
)

核心实现代码

// BatchCreate 批量创建消息接收者记录 - 使用 COPY 协议
func (r *receiverRepo) BatchCreate(ctx context.Context, db *gorm.DB, data []*define.WecomMsgReceiver) (rowsAffected int64, err error) {
    db = r.WithTrace(ctx, db)
    db = db.Table(r.TableName())

    if len(data) == 0 {
        return 0, nil
    }

    // 过滤掉 nil 的数据
    validData := make([]*define.WecomMsgReceiver, 0, len(data))
    for _, item := range data {
        if item != nil {
            validData = append(validData, item)
        }
    }
    if len(validData) == 0 {
        return 0, nil
    }

    // 获取底层 sql.DB
    sqlDB := db.DB()
    
    // 开始事务
    tx, err := sqlDB.BeginTx(ctx, nil)
    if err != nil {
        return 0, fmt.Errorf("开始事务失败:%+v", err)
    }
    defer func() {
        if err != nil {
            tx.Rollback()
        }
    }()

    // 创建 COPY writer
    stmt, err := tx.Prepare(pq.CopyIn(r.TableName(), "send_log_id", "user_id", "status", "created_at", "updated_at"))
    if err != nil {
        return 0, fmt.Errorf("准备 COPY 语句失败:%+v", err)
    }
    defer stmt.Close()

    // 批量写入数据
    for _, item := range validData {
        _, err = stmt.Exec(item.SendLogID, item.UserID, item.Status, item.CreatedAt, item.UpdatedAt)
        if err != nil {
            return 0, fmt.Errorf("写入数据失败:%+v", err)
        }
    }

    // 执行 COPY
    _, err := stmt.Exec()
    if err != nil {
        return 0, fmt.Errorf("执行 COPY 失败:%+v", err)
    }

    // 提交事务
    if err = tx.Commit(); err != nil {
        return 0, fmt.Errorf("提交事务失败:%+v", err)
    }
    
    rowsAffected = int64(len(validData))
    return rowsAffected, nil
}

代码说明

  1. 数据验证:首先过滤掉 nil 数据,确保数据有效性 
  2. 事务管理:使用事务确保数据一致性,出错时自动回滚
  3. COPY 准备:通过 pq.CopyIn 准备 COPY 语句,指定表名和列名
  4. 批量写入:遍历数据并执行 Exec,但此时数据还在客户端缓冲区
  5. 最终执行:调用 stmt.Exec() 真正将数据发送到服务器
  6. 事务提交:提交事务,完成批量写入

完整测试用例

// 设置测试数据库
func setupTestDB() (*gorm.DB, error) {
    ctx := context.Background()
    postgres, err := infrastructure.DialPostgres(ctx, infrastructure.PostgresConfig{
        Host:     "host",
        Port:     5432,
        Username: "postgres",
        Password: "xxxxx",
        Database: "xxxxx",
    })
    if err != nil {
        return nil, err
    }

    return postgres, nil
}

func setupLogger() factory.LogFactory {
    logger, _ := factory.NewJsonFactory(factory.NewLevel("info"), factory.NewZapOption(factory.AddCallerSkip(0)))
    return logger
}

func TestReceiverRepo_BatchCreate(t *testing.T) {
    db, err := setupTestDB()
    require.NoError(t, err)
    defer db.Close()

    // 创建日志工厂
    logger := setupLogger()

    // 创建 repository 实例
    repo := NewReceiverRepository(db, logger)

    // 准备测试数据 - 20000 条记录,使用负的 send_log_id 避免污染数据
    testData := make([]*define.WecomMsgReceiver, 0, 20000)
    now := time.Now()
    negativeSendLogID := int64(-100000) // 使用负的 send_log_id

    for i := 0; i < 20000; i++ {
        testData = append(testData, &define.WecomMsgReceiver{
            SendLogID: negativeSendLogID,
            UserID:    "test_user_" + fmt.Sprint(i),
            Status:    1,
            CreatedAt: now,
            UpdatedAt: now,
        })
    }

    ctx := context.Background()

    // 执行批量插入
    rowsAffected, err := repo.BatchCreate(ctx, db, testData)

    // 验证结果
    assert.NoError(t, err)
    assert.Equal(t, int64(20000), rowsAffected)

    // 验证数据是否正确插入
    var count int64
    query := "SELECT COUNT(*) FROM wecom_msg_receiver WHERE send_log_id < 0 AND send_log_id >= ?"
    err = db.Raw(query, negativeSendLogID).Count(&count).Error
    assert.NoError(t, err)
    assert.Equal(t, int64(20000), count)

    // 清理测试数据
    deleteQuery := "DELETE FROM wecom_msg_receiver WHERE send_log_id < 0 AND send_log_id >= ?"
    result := db.Exec(deleteQuery, negativeSendLogID)
    assert.NoError(t, result.Error)
    assert.Equal(t, int64(20000), result.RowsAffected)

    // 验证清理是否成功
    err = db.Raw(query, negativeSendLogID).Count(&count).Error
    assert.NoError(t, err)
    assert.Equal(t, int64(0), count)
}

性能对比

在实际测试中,COPY 协议相比传统分批插入有显著性能提升:

方案20000 条数据耗时内存占用网络请求次数
传统分批插入~15 秒多次
COPY 协议~2 秒1 次

注意事项

  1. 错误处理:COPY 协议中某行数据错误可能导致整个批量操作失败
  2. 数据类型:确保 Go 数据类型与 PostgreSQL 列类型匹配
  3. 连接池:长时间运行的 COPY 操作会占用数据库连接
  4. 超时设置:对于大数据量,需要适当调整上下文超时时间

总结

通过使用 PostgreSQL 的 COPY 协议,我们成功解决了批量写入时的参数数量限制问题,同时大幅提升了性能。这种方法特别适合数据迁移、日志批量处理等需要高效写入大量数据的场景。

COPY协议结合事务管理,既保证了数据一致性,又能提供了接近原生的写入性能,是PostgreSQL批量数据操作的优选方案。

以上就是PostgreSQL使用COPY协议高效批量数据写入的实战指南的详细内容,更多关于PostgreSQL COPY批量数据写入的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文