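How to Use UPSERT in SQLAlchemy

Author: 花酒锄作田

This article covers UPSERT in SQLite and PostgreSQL: syntax, caveats, and returning results. Both databases accept a WHERE clause on DO UPDATE, and per-column conditions can be expressed with COALESCE or CASE.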

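Introduction

Both SQLite (3.24+) and PostgreSQL (9.5+) support UPSERT, i.e. "update the row if it exists, insert it otherwise". The conflict target columns must be covered by a unique constraint or a primary key.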
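
As a minimal sketch of the behaviour described above, the following uses Python's standard-library sqlite3 module rather than SQLAlchemy, so that it runs without a database server. The table layout and the helper names (upsert_user, demo_upsert, fails_without_unique_constraint) are assumptions made for illustration. It shows one insert followed by one update on the same email, and that SQLite rejects the statement when the conflict target has no unique constraint.

```python
import sqlite3


def upsert_user(conn: sqlite3.Connection, email: str, name: str) -> None:
    """Insert a user, or overwrite the name when the email already exists."""
    conn.execute(
        """
        INSERT INTO users (email, name) VALUES (?, ?)
        ON CONFLICT (email) DO UPDATE SET name = excluded.name
        """,
        (email, name),
    )


def demo_upsert() -> list:
    conn = sqlite3.connect(":memory:")
    # email is UNIQUE, so it is a valid conflict target
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE, name TEXT)"
    )
    upsert_user(conn, "test@example.com", "old name")  # no conflict: inserts
    upsert_user(conn, "test@example.com", "new name")  # conflict: updates
    return conn.execute("SELECT email, name FROM users").fetchall()


def fails_without_unique_constraint() -> bool:
    conn = sqlite3.connect(":memory:")
    # Same table, but email has no UNIQUE constraint
    conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT, name TEXT)")
    try:
        upsert_user(conn, "test@example.com", "name")
    except sqlite3.OperationalError:
        return True  # SQLite refuses a conflict target without a unique index
    return False
```

After the two calls in demo_upsert the table still holds a single row, carrying the second name.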

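Syntax:

| Scenario | PostgreSQL | SQLite | Notes |
| --- | --- | --- | --- |
| Basic UPSERT | ON CONFLICT (col) DO UPDATE SET ... | ON CONFLICT (col) DO UPDATE SET ... | Same; the space before the parenthesis is not significant |
| Ignore conflict | ON CONFLICT (col) DO NOTHING | ON CONFLICT (col) DO NOTHING | Same |
| Reference the new value | EXCLUDED.col | excluded.col | Same; the name is case-insensitive in both |
| Return results | RETURNING * | RETURNING * | Same; SQLite requires 3.35+ |
| Conditional update | DO UPDATE SET ... WHERE condition | DO UPDATE SET ... WHERE condition | Supported by both |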
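
SQLite's documented upsert grammar accepts a WHERE clause after DO UPDATE SET. The sketch below checks this with the standard-library sqlite3 module. The version column and the helper names are hypothetical, introduced only to have something to compare: the row is overwritten only when the incoming version is higher than the stored one.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    # version is a hypothetical column used only for this illustration
    conn.execute(
        "CREATE TABLE users (email TEXT PRIMARY KEY, name TEXT, version INTEGER)"
    )
    return conn


def upsert_if_newer(conn, email, name, version):
    """Upsert, but only overwrite when the incoming version is higher."""
    conn.execute(
        """
        INSERT INTO users (email, name, version) VALUES (?, ?, ?)
        ON CONFLICT (email) DO UPDATE SET
            name = excluded.name,
            version = excluded.version
        WHERE excluded.version > users.version
        """,
        (email, name, version),
    )
    # Read back the stored row to see whether the update was applied
    return conn.execute(
        "SELECT name, version FROM users WHERE email = ?", (email,)
    ).fetchone()
```

When the WHERE condition is false, the conflicting row is simply left unchanged and no error is raised.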

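Notes

EXCLUDED and RETURNING

EXCLUDED

EXCLUDED refers to the row that was proposed for insertion and rejected because of the conflict.

INSERT INTO users (email, name, age)
VALUES ('test@example.com', 'new name', 30)
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,   -- the proposed value 'new name'
    age = EXCLUDED.age      -- the proposed value 30

| Scenario | Expression | Meaning | Example value |
| --- | --- | --- | --- |
| Existing column | users.name | Current value in the conflicting row | "old name" |
| New value | EXCLUDED.name | Value the INSERT tried to write | "new name" |
| Mixed computation | users.age + EXCLUDED.age | Existing value + new value | 25 + 30 = 55 |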

Example 1: accumulating stock

-- Accumulate product stock: existing 100 + incoming 50 = 150
INSERT INTO products (sku, stock)
VALUES ('IPHONE15', 50)
ON CONFLICT (sku) DO UPDATE SET
    stock = products.stock + EXCLUDED.stock  -- 100 + 50
RETURNING stock;
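
The accumulation can be tried locally with a short sketch, assuming the standard-library sqlite3 module and a cut-down products table (both chosen here only to keep the snippet self-contained). It reads the stock back with a separate SELECT instead of RETURNING so that it also runs on SQLite versions older than 3.35.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE products (id INTEGER PRIMARY KEY, sku TEXT UNIQUE, stock INTEGER)"
    )
    return conn


def add_stock(conn: sqlite3.Connection, sku: str, quantity: int) -> int:
    """Insert the product, or add quantity to the existing stock."""
    conn.execute(
        """
        INSERT INTO products (sku, stock) VALUES (?, ?)
        ON CONFLICT (sku) DO UPDATE SET
            stock = products.stock + excluded.stock
        """,
        (sku, quantity),
    )
    # Return the stock as stored after the upsert
    return conn.execute(
        "SELECT stock FROM products WHERE sku = ?", (sku,)
    ).fetchone()[0]
```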

Example 2: updating only non-NULL fields

-- If the new value is NULL, keep the existing value
INSERT INTO users (email, name, age)
VALUES ('test@example.com', 'new name', NULL)
ON CONFLICT (email) DO UPDATE SET
    name = COALESCE(EXCLUDED.name, users.name),  -- becomes 'new name'
    age = COALESCE(EXCLUDED.age, users.age)      -- keeps the existing age
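
A small sketch of the COALESCE pattern, again assuming the standard-library sqlite3 module and a simplified users table; the helper names are illustrative. Passing None for a column leaves the stored value in place, while a non-None value overwrites it.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (email TEXT PRIMARY KEY, name TEXT, age INTEGER)"
    )
    return conn


def upsert_non_null(conn, email, name=None, age=None):
    """On conflict, overwrite a column only when the new value is not NULL."""
    conn.execute(
        """
        INSERT INTO users (email, name, age) VALUES (?, ?, ?)
        ON CONFLICT (email) DO UPDATE SET
            name = COALESCE(excluded.name, users.name),
            age = COALESCE(excluded.age, users.age)
        """,
        (email, name, age),
    )
    return conn.execute(
        "SELECT name, age FROM users WHERE email = ?", (email,)
    ).fetchone()
```

Note that this pattern cannot be used to set a column back to NULL on purpose, since NULL is treated as "not provided".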

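Example 3: refreshing a timestamp

-- Refresh updated_at whenever the row is updated
INSERT INTO users (email, name)
VALUES ('test@example.com', 'new name')
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,
    updated_at = NOW()  -- PostgreSQL
    -- updated_at = CURRENT_TIMESTAMP  -- SQLite

RETURNING

RETURNING returns the rows affected by a statement. Appended to an INSERT, UPDATE, or DELETE, it returns the listed columns directly and avoids an extra SELECT: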

INSERT INTO users (email, name)
VALUES ('test@example.com', '张三')
RETURNING id, email, name, created_at;
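
The sketch below shows RETURNING through the standard-library sqlite3 module, which is an assumption made to keep it runnable without a server. Because RETURNING needs SQLite 3.35 or later, the hypothetical helper insert_user falls back to lastrowid plus a SELECT on older library versions.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE, name TEXT)"
    )
    return conn


def insert_user(conn: sqlite3.Connection, email: str, name: str) -> tuple:
    """Insert a user and return (id, email, name) for the new row."""
    if sqlite3.sqlite_version_info >= (3, 35, 0):
        # One round trip: the INSERT itself yields the new row
        cur = conn.execute(
            "INSERT INTO users (email, name) VALUES (?, ?) "
            "RETURNING id, email, name",
            (email, name),
        )
        return cur.fetchone()
    # Fallback for older SQLite: insert, then read the row back
    cur = conn.execute(
        "INSERT INTO users (email, name) VALUES (?, ?)", (email, name)
    )
    return conn.execute(
        "SELECT id, email, name FROM users WHERE id = ?", (cur.lastrowid,)
    ).fetchone()
```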

Example 1: getting the ID right after insert

from sqlalchemy import text

# PostgreSQL / SQLite 3.35+
sql = text("""
    INSERT INTO users (email, name)
    VALUES (:email, :name)
    RETURNING id, email, created_at
""")
result = await session.execute(sql, {"email": "test@example.com", "name": "张三"})
user = result.mappings().first()
print(user["id"])  # the new ID, without a second query

Example 2: a uniform result after UPSERT

-- Return the final row state whether the row was inserted or updated
INSERT INTO users (email, name, login_count)
VALUES ('test@example.com', '张三', 1)
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,
    login_count = users.login_count + 1  -- increment the login count
RETURNING 
    id,
    email,
    name,
    login_count,
    CASE 
        WHEN xmax = 0 THEN 'inserted'  -- PostgreSQL only: relies on the xmax system column being 0 for a fresh insert
        ELSE 'updated'
    END AS action

Example 3: returning all rows from a bulk operation

-- RETURNING also works with multi-row inserts
INSERT INTO users (email, name)
VALUES 
    ('a@example.com', 'A'),
    ('b@example.com', 'B')
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name
RETURNING id, email, name;

Handling the bulk result in Python:

result = await session.execute(sql)
users = [dict(row) for row in result.mappings().all()]
# [{'id': 1, 'email': 'a@example.com', 'name': 'A'}, ...]

Example: user login counter

from datetime import datetime

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


async def record_user_login(session: AsyncSession, email: str, name: str) -> dict | None:
    """
    User login counter:
    - new user: insert with login_count = 1
    - existing user: update with login_count += 1
    - returns the final row plus the action taken
    """
    sql = text("""
        INSERT INTO users (
            email, name, login_count, last_login, created_at
        ) VALUES (
            :email, :name, 1, :now, :now
        )
        ON CONFLICT (email) DO UPDATE SET
            name = EXCLUDED.name,                          -- update the user name
            login_count = users.login_count + 1,           -- increment the login count
            last_login = EXCLUDED.last_login               -- refresh the last login time
        RETURNING
            id,
            email,
            name,
            login_count,
            last_login,
            created_at,
            CASE 
                WHEN xmax = 0 THEN 'inserted' 
                ELSE 'updated' 
            END AS action  -- PostgreSQL only, tells insert from update
    """)
    now = datetime.utcnow()
    result = await session.execute(
        sql,
        {"email": email, "name": name, "now": now}
    )
    row = result.mappings().first()
    return dict(row) if row else None


# Usage
user = await record_user_login(session, "test@example.com", "张三")
print(f"{user['action']} user {user['email']} with {user['login_count']} logins")
# Output: inserted user test@example.com with 1 logins
# or:     updated user test@example.com with 5 logins

Example data models

from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # unique=True already creates the unique constraint that ON CONFLICT (email) needs
    email = Column(String(100), unique=True, nullable=False)
    name = Column(String(50))
    age = Column(Integer)
    balance = Column(Integer, default=0)
    # Columns referenced by the examples in this article
    login_count = Column(Integer, default=0)
    last_login = Column(DateTime)
    created_at = Column(DateTime)
    updated_at = Column(DateTime)


class Product(Base):
    __tablename__ = "products"

    id = Column(Integer, primary_key=True)
    sku = Column(String(50), unique=True, nullable=False)  # unique SKU
    name = Column(String(100))
    stock = Column(Integer, default=0)
    price = Column(Integer)
    created_at = Column(DateTime)
    updated_at = Column(DateTime)

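ORM approach

Pay attention to where insert is imported from. on_conflict_do_update() and on_conflict_do_nothing() exist only on the dialect-specific constructs, sqlalchemy.dialects.postgresql.insert and sqlalchemy.dialects.sqlite.insert. The generic sqlalchemy.insert does not provide them.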
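
To see what the import path changes, the following sketch builds the same upsert with each dialect's insert and compiles it to SQL text without connecting to a database. It assumes SQLAlchemy 1.4 or later is installed; the Core users table and the helpers build_upsert and render_upsert are illustrative names, not part of the article's code.

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, insert
from sqlalchemy.dialects import postgresql, sqlite
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
users = Table(
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("email", String(100), unique=True),
    Column("name", String(50)),
)

# Map a dialect name to its insert construct and dialect class
_INSERTS = {"postgresql": pg_insert, "sqlite": sqlite_insert}
_DIALECTS = {"postgresql": postgresql.dialect, "sqlite": sqlite.dialect}


def build_upsert(dialect_name: str):
    """Build INSERT ... ON CONFLICT (email) DO UPDATE for the given dialect."""
    stmt = _INSERTS[dialect_name](users).values(email="test@example.com", name="new")
    return stmt.on_conflict_do_update(
        index_elements=["email"],
        set_={"name": stmt.excluded.name},  # excluded = the proposed row
    )


def render_upsert(dialect_name: str) -> str:
    """Compile the statement to SQL text; no database connection is needed."""
    return str(build_upsert(dialect_name).compile(dialect=_DIALECTS[dialect_name]()))


if __name__ == "__main__":
    print(render_upsert("postgresql"))
    print(render_upsert("sqlite"))
```

A project that targets both databases could pick the construct from the name of the engine's dialect in a similar way.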

Basic example

from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession

# on_conflict_do_update() / on_conflict_do_nothing() exist only on the
# dialect-specific insert constructs; the generic sqlalchemy.insert lacks them.
from sqlalchemy.dialects.postgresql import insert  # PostgreSQL
# from sqlalchemy.dialects.sqlite import insert    # SQLite


async def upsert_user_orm(session: AsyncSession, user_data: dict) -> dict:
    """
    UPSERT a user (ORM style).
    Updates the row if the email conflicts, inserts it otherwise.
    """
    stmt = (
        insert(User)
        .values(**user_data)
        .on_conflict_do_update(
            index_elements=["email"],  # conflict target (unique constraint)
            set_={
                "name": user_data["name"],
                "age": user_data.get("age"),
                "updated_at": func.now(),
            },
        )
        .returning(User)  # return the inserted or updated row
    )
    # populate_existing refreshes a User that is already loaded in this session
    result = await session.execute(stmt, execution_options={"populate_existing": True})
    user = result.scalar_one()
    return {
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "age": user.age
    }


async def upsert_user_ignore(session: AsyncSession, user_data: dict) -> bool:
    """
    UPSERT that ignores conflicts (DO NOTHING).
    """
    stmt = (
        insert(User)
        .values(**user_data)
        .on_conflict_do_nothing(
            index_elements=["email"]
        )
    )
    result = await session.execute(stmt)
    return result.rowcount > 0  # True if a row was inserted

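Conditional update: updating only specific fields

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import insert  # dialect-specific insert


async def upsert_user_conditional(session: AsyncSession, user_data: dict) -> dict:
    """
    UPSERT: on conflict, overwrite only the fields whose new value is not NULL.
    """
    stmt = insert(User).values(**user_data)
    stmt = stmt.on_conflict_do_update(
        index_elements=["email"],
        set_={
            # COALESCE keeps the existing value when the proposed one is NULL
            "name": func.coalesce(stmt.excluded.name, User.name),
            "age": func.coalesce(stmt.excluded.age, User.age),
        },
        # Optional: pass where=<condition> to restrict the update further.
        # When the condition is false the row is left unchanged and
        # RETURNING yields no row for it.
    ).returning(User)
    result = await session.execute(stmt)
    user = result.scalar_one()
    return {"id": user.id, "email": user.email, "name": user.name, "age": user.age}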

Bulk UPSERT

from sqlalchemy.dialects.postgresql import insert  # dialect-specific insert


async def bulk_upsert_users(session: AsyncSession, users: list[dict]) -> int:
    """
    Bulk UPSERT users.
    """
    stmt = insert(User).values(users)
    stmt = stmt.on_conflict_do_update(
        index_elements=["email"],
        set_={
            "name": stmt.excluded.name,  # excluded refers to the proposed row
            "age": stmt.excluded.age,
        },
    )
    result = await session.execute(stmt)
    return result.rowcount

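Referencing new values with EXCLUDED

from sqlalchemy.dialects.postgresql import insert  # dialect-specific insert


async def upsert_product_with_stock(session: AsyncSession, product_data: dict) -> dict:
    """
    UPSERT a product: on conflict, add the incoming stock to the existing stock.
    """
    stmt = insert(Product).values(**product_data)
    stmt = stmt.on_conflict_do_update(
        index_elements=["sku"],
        set_={
            # Accumulate stock: existing stock + incoming stock
            "stock": Product.stock + stmt.excluded.stock,
            # Overwrite the other fields
            "name": stmt.excluded.name,
            "price": stmt.excluded.price,
        },
    ).returning(Product)
    result = await session.execute(stmt)
    product = result.scalar_one()
    return {
        "id": product.id,
        "sku": product.sku,
        "name": product.name,
        "stock": product.stock,
        "price": product.price,
    }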
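
The same accumulate-on-conflict statement can be exercised end to end with the sketch below. It swaps in a synchronous in-memory SQLite engine and a Core Table in place of the async session and ORM models used in this article, purely so that the snippet is self-contained; it assumes SQLAlchemy 1.4 or later, and add_stock and demo are illustrative names.

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
products = Table(
    "products",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("sku", String(50), unique=True, nullable=False),
    Column("stock", Integer, default=0),
)


def add_stock(conn, sku: str, quantity: int) -> int:
    """Insert the product, or add quantity to its stock; return the new stock."""
    stmt = sqlite_insert(products).values(sku=sku, stock=quantity)
    stmt = stmt.on_conflict_do_update(
        index_elements=["sku"],
        # existing stock + proposed stock
        set_={"stock": products.c.stock + stmt.excluded.stock},
    )
    conn.execute(stmt)
    return conn.execute(
        select(products.c.stock).where(products.c.sku == sku)
    ).scalar_one()


def demo() -> tuple:
    engine = create_engine("sqlite://")  # in-memory database
    metadata.create_all(engine)
    with engine.begin() as conn:
        first = add_stock(conn, "IPHONE15", 100)   # inserts
        second = add_stock(conn, "IPHONE15", 50)   # conflicts, accumulates
    return first, second
```

Taking excluded from the statement being built (stmt.excluded) keeps the reference tied to that statement instead of constructing a second, throwaway insert.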

User service

from datetime import datetime

from sqlalchemy.dialects.postgresql import insert  # dialect-specific insert
from sqlalchemy.ext.asyncio import AsyncSession


class UserService:
    """User service with UPSERT support."""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_or_update(self, email: str, name: str, age: int | None = None) -> dict:
        """Create or update a user."""
        stmt = (
            insert(User)
            .values(
                email=email,
                name=name,
                age=age,
                created_at=datetime.utcnow()
            )
            .on_conflict_do_update(
                index_elements=["email"],
                set_={
                    "name": name,
                    "age": age,
                    "updated_at": datetime.utcnow()
                }
            )
            .returning(User)
        )
        result = await self.session.execute(stmt)
        user = result.scalar_one()
        return {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "age": user.age
        }

    async def bulk_create_or_update(self, users: list[dict]) -> int:
        """Create or update users in bulk."""
        stmt = insert(User).values(users)
        stmt = stmt.on_conflict_do_update(
            index_elements=["email"],
            set_={
                "name": stmt.excluded.name,
                "age": stmt.excluded.age,
                "updated_at": datetime.utcnow()
            }
        )
        result = await self.session.execute(stmt)
        return result.rowcount

    async def create_if_not_exists(self, email: str, name: str) -> bool:
        """Create the user only if it does not exist yet."""
        stmt = (
            insert(User)
            .values(
                email=email,
                name=name,
                created_at=datetime.utcnow()
            )
            .on_conflict_do_nothing(
                index_elements=["email"]
            )
        )
        result = await self.session.execute(stmt)
        return result.rowcount > 0  # True = inserted, False = already existed

Raw SQL

Basic example

PostgreSQL

from datetime import datetime

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession


async def upsert_user_pg(session: AsyncSession, user_data: dict) -> dict | None:
    """
    Raw UPSERT for PostgreSQL.
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO UPDATE  -- conflict target
        SET 
            name = EXCLUDED.name,      -- EXCLUDED is the proposed row
            age = EXCLUDED.age,
            updated_at = NOW()
        RETURNING id, email, name, age
    """)
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    row = result.mappings().first()
    return dict(row) if row else None

SQLite

async def upsert_user_sqlite(session: AsyncSession, user_data: dict) -> dict | None:
    """
    Raw UPSERT for SQLite (the ON CONFLICT syntax matches PostgreSQL;
    only NOW() has to become CURRENT_TIMESTAMP). RETURNING needs SQLite 3.35+.
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO UPDATE SET
            name = excluded.name,
            age = excluded.age,
            updated_at = CURRENT_TIMESTAMP
        RETURNING id, email, name, age
    """)
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    row = result.mappings().first()
    return dict(row) if row else None

Ignoring conflicts

async def insert_or_ignore_user(session: AsyncSession, user_data: dict) -> bool:
    """
    Insert a user; do nothing if the email already exists.
    """
    # The same statement works on PostgreSQL and SQLite
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO NOTHING
    """)
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    return result.rowcount > 0  # True if a row was inserted
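
The rowcount check can be observed in isolation with a sketch based on the standard-library sqlite3 module (an assumption made so that it runs without SQLAlchemy or a server; the table and helper names are illustrative). The first call inserts a row and reports True; the second hits the conflict, changes nothing, and reports False.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE, name TEXT)"
    )
    return conn


def insert_or_ignore(conn: sqlite3.Connection, email: str, name: str) -> bool:
    """Insert the user unless the email exists; report whether a row was added."""
    cur = conn.execute(
        """
        INSERT INTO users (email, name) VALUES (?, ?)
        ON CONFLICT (email) DO NOTHING
        """,
        (email, name),
    )
    # rowcount is the number of rows the statement changed: 1 or 0 here
    return cur.rowcount > 0
```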

Bulk UPSERT

async def bulk_upsert_products(session: AsyncSession, products: list[dict]) -> int:
    """
    Bulk UPSERT products (raw SQL).
    """
    # PostgreSQL
    sql = text("""
        INSERT INTO products (sku, name, stock, price, created_at)
        VALUES (
            :sku, :name, :stock, :price, :created_at
        )
        ON CONFLICT (sku) DO UPDATE SET
            name = EXCLUDED.name,
            stock = products.stock + EXCLUDED.stock,  -- accumulate stock
            price = EXCLUDED.price,
            updated_at = NOW()
    """)
    now = datetime.utcnow()
    params = [
        {
            "sku": product["sku"],
            "name": product["name"],
            "stock": product.get("stock", 0),
            "price": product.get("price", 0),
            "created_at": now
        }
        for product in products
    ]
    # A list of parameter dicts runs the statement once per row (executemany)
    if params:
        await session.execute(sql, params)
    return len(products)
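
For comparison, here is a minimal sketch of the same bulk pattern with the standard-library sqlite3 module and executemany, assuming a cut-down products table. Each parameter dict is applied in order, so a SKU that already exists has its stock accumulated and its name overwritten.

```python
import sqlite3


def make_conn() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE products (sku TEXT PRIMARY KEY, name TEXT, stock INTEGER)"
    )
    return conn


def bulk_upsert_products(conn: sqlite3.Connection, products: list) -> list:
    """Upsert every product dict, then return all rows ordered by sku."""
    conn.executemany(
        """
        INSERT INTO products (sku, name, stock) VALUES (:sku, :name, :stock)
        ON CONFLICT (sku) DO UPDATE SET
            name = excluded.name,
            stock = products.stock + excluded.stock
        """,
        products,  # one dict of named parameters per row
    )
    return conn.execute(
        "SELECT sku, name, stock FROM products ORDER BY sku"
    ).fetchall()
```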

Partial update with conditions

async def upsert_user_smart(session: AsyncSession, user_data: dict) -> dict | None:
    """
    Selective UPSERT (PostgreSQL):
    - age is updated only if a value was provided
    - name is updated only if a value was provided
    - updated_at is always refreshed
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO UPDATE SET
            name = COALESCE(EXCLUDED.name, users.name),  -- keep the existing value if the new one is NULL
            age = COALESCE(EXCLUDED.age, users.age),
            updated_at = NOW()
        RETURNING id, email, name, age, updated_at
    """)
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data.get("name"),  # may be None
            "age": user_data.get("age"),    # may be None
            "created_at": datetime.utcnow()
        }
    )
    row = result.mappings().first()
    return dict(row) if row else None

User registration/login: refresh the last login time if the user exists

async def register_or_login(session: AsyncSession, email: str, name: str) -> dict:
    """
    Register or log in a user:
    - new user: insert
    - existing user: refresh the last login time
    """
    sql = text("""
        INSERT INTO users (email, name, last_login, created_at)
        VALUES (:email, :name, :now, :now)
        ON CONFLICT (email) DO UPDATE SET
            last_login = EXCLUDED.last_login,
            name = EXCLUDED.name  -- optional, also update the user name
        RETURNING id, email, name, last_login, created_at
    """)
    now = datetime.utcnow()
    result = await session.execute(
        sql,
        {"email": email, "name": name, "now": now}
    )
    # DO UPDATE without WHERE always returns exactly one row
    return dict(result.mappings().one())

Accumulating stock

async def add_product_stock(session: AsyncSession, sku: str, quantity: int) -> bool:
    """
    Increase product stock:
    - product does not exist: insert it
    - product exists: add to its stock
    """
    sql = text("""
        INSERT INTO products (sku, stock, created_at)
        VALUES (:sku, :quantity, :now)
        ON CONFLICT (sku) DO UPDATE SET
            stock = products.stock + EXCLUDED.stock,
            updated_at = NOW()
    """)
    result = await session.execute(
        sql,
        {
            "sku": sku,
            "quantity": quantity,
            "now": datetime.utcnow()
        }
    )
    return result.rowcount > 0

Accumulating user points

async def add_user_points(session: AsyncSession, user_id: int, points: int) -> dict | None:
    """
    Add points to a user (accumulating). Assumes a user_points table
    with a unique constraint on user_id.
    """
    sql = text("""
        INSERT INTO user_points (user_id, points, created_at)
        VALUES (:user_id, :points, :now)
        ON CONFLICT (user_id) DO UPDATE SET
            points = user_points.points + EXCLUDED.points,
            updated_at = NOW()
        RETURNING user_id, points
    """)
    result = await session.execute(
        sql,
        {
            "user_id": user_id,
            "points": points,
            "now": datetime.utcnow()
        }
    )
    row = result.mappings().first()
    return dict(row) if row else None

Tag counter

Increment the count if the tag exists, create it otherwise:

async def increment_tag_count(session: AsyncSession, tag_name: str) -> int:
    """
    Tag counter:
    - tag does not exist: insert with count = 1
    - tag exists: count += 1
    """
    sql = text("""
        INSERT INTO tags (name, count, created_at)
        VALUES (:name, 1, :now)
        ON CONFLICT (name) DO UPDATE SET
            count = tags.count + 1,
            updated_at = NOW()
        RETURNING count
    """)
    result = await session.execute(
        sql,
        {"name": tag_name, "now": datetime.utcnow()}
    )
    return result.scalar() or 0
