Redis

关注公众号 jb51net

关闭
首页 > 数据库 > Redis > Redis lua读写锁

利用Redis lua实现高效读写锁的代码实例

作者:Goland猫

这篇文章给大家介绍了如何利用Redis lua实现高效的读写锁,读写锁的好处就是能帮助客户读到的数据一定是最新的,写锁是排他锁,而读锁是一个共享锁,需要的朋友可以参考下

前言

读写锁的好处就是能帮助客户读到的数据一定是最新的,写锁是排他锁,而读锁是一个共享锁,如果写锁一直存在,那么读取数据就要一直等待,直到写入数据完成才能看到,保证了数据的一致性

一、为什么使用Lua

Lua脚本是高并发、高性能的必备脚本语言, 大部分的开源框架(如:redission)中的分布式锁组件,都是用纯lua脚本实现的。

那么,为什么要使用Lua语言来实现分布式锁呢?我们从一个案例看起:

所以,只有确保判断锁和删除锁是一步操作时,才能避免上面的问题,才能确保原子性。

其实很简单,首先获取锁对应的value值,检查是否与requestId相等,如果相等则删除锁(解锁)。虽然看似做了两件事,但是却只有一个完整的原子操作。

第一行代码,我们写了一个简单的 Lua 脚本代码; 第二行代码,我们将Lua代码传到 edis.eval()方法里,并使参数 KEYS[1] 赋值为 lockKey,ARGV[1] 赋值为 requestId,eval() 方法是将Lua代码交给 Redis 服务端执行。

二、执行流程

加锁和删除锁的操作,使用纯 Lua 进行封装,保障其执行时候的原子性。

基于纯Lua脚本实现分布式锁的执行流程,大致如下:

三、代码详解

lua\lock.lua

-- KEYS = [LOCK_KEY, LOCK_INTENT]
-- ARGV = [LOCK_ID, TTL]
local t = redis.call('TYPE', KEYS[1])["ok"]
if t == "string" then
   return redis.call('PTTL', KEYS[1])
end

if redis.call("EXISTS", KEYS[2]) == 1 then
   return redis.call('PTTL', KEYS[2])
end

redis.call('SADD', KEYS[1], ARGV[1])
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return nil

它首先尝试通过 SET 命令将 LOCK_KEY 存储到 Redis 中,如果设置失败,则表示锁已被其他进程占用,返回锁的剩余过期时间。如果设置成功,则删除 LOCK_INTENT 键,表示锁已成功获取

lua\refresh.lua

-- KEYS = [LOCK_KEY]
-- ARGV = [LOCK_ID, TTL]
local t = redis.call('TYPE', KEYS[1])["ok"]
if (t == "string" and redis.call('GET', KEYS[1]) ~= ARGV[1]) or
        (t == "set" and redis.call('SISMEMBER', KEYS[1], ARGV[1]) == 0) or
        (t == "none") then
    return 0
end

return redis.call('PEXPIRE', KEYS[1], ARGV[2])

lua\rlock.lua

-- KEYS = [LOCK_KEY, LOCK_INTENT]
-- ARGV = [LOCK_ID, TTL]
local t = redis.call('TYPE', KEYS[1])["ok"]
if t == "string" then
   return redis.call('PTTL', KEYS[1])
end

if redis.call("EXISTS", KEYS[2]) == 1 then
   return redis.call('PTTL', KEYS[2])
end

redis.call('SADD', KEYS[1], ARGV[1])
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return nil

lua\unlock.lua

-- KEYS = [LOCK_KEY]
-- ARGV = [LOCK_ID]
local t = redis.call('TYPE', KEYS[1])["ok"]
if t == "string" and redis.call('GET', KEYS[1]) == ARGV[1] then
    return redis.call('DEL', KEYS[1])
elseif t == "set" and redis.call('SISMEMBER', KEYS[1], ARGV[1]) == 1 then
    redis.call('SREM', KEYS[1], ARGV[1])
    if redis.call('SCARD', KEYS[1]) == 0 then
        return redis.call('DEL', KEYS[1])
    end
end
return 1

写优先还是读优先?

写锁会阻塞读锁,所以是写优先

写锁是如何阻塞写锁的?

如果当前的写锁已经被占用,其他写锁的获取请求会被阻塞,因为在释放锁的逻辑中,会先判断锁的类型,如果是写锁,则会判断当前锁的值是否符合预期,从而判断能否删除该锁。

读锁与读锁之间互斥吗?

对于读锁而言,多个读锁之间是可以并发持有的,因此读锁之间默认是不会互斥的,可以同时执行读操作。

写锁会有被饿死的情况吗?

写优先锁可以保证写线程不会饿死,但是如果一直有写线程获取写锁,读线程也会被「饿死」。

既然不管优先读锁还是写锁,对方可能会出现饿死问题,那么我们就不偏袒任何一方,搞个「公平读写锁」。

公平读写锁比较简单的一种方式是:用队列把获取锁的线程排队,不管是写线程还是读线程都按照先进先出的原则加锁即可,这样读线程仍然可以并发,也不会出现「饥饿」的现象。

抽象lock类

import (
	"context"
	"errors"
	"time"

	"github.com/redis/go-redis/v9"
)

var _ context.Context = (*Lock)(nil)

// Lock represents a lock with context.
type Lock struct {
	redis  redis.Scripter
	id     string
	ttl    time.Duration
	key    string
	log    LogFunc
	ctx    context.Context
	cancel context.CancelFunc
}

// ID returns the id value set by the lock.
func (l *Lock) ID() string {
	return l.id
}

// Key returns the key value set by the lock.
func (l *Lock) Key() string {
	return l.key
}

func (l *Lock) Deadline() (deadline time.Time, ok bool) {
	return l.ctx.Deadline()
}

func (l *Lock) Done() <-chan struct{} {
	return l.ctx.Done()
}

func (l *Lock) Err() error {
	return l.ctx.Err()
}

func (l *Lock) Value(key any) any {
	return l.ctx.Value(key)
}

// Unlock unlocks.
func (l *Lock) Unlock() {
	l.cancel()
	_, err := scriptUnlock.Run(context.Background(), l.redis, []string{l.key}, l.id).Result()
	if err != nil {
		l.log("[ERROR] unlock %q %s: %v", l.key, l.id, err)
	}
}

func (l *Lock) refreshTTL(left time.Time) {
	defer l.cancel()

	refresh := l.updateTTL()
	for {
		diff := time.Since(left)
		select {

		case <-l.ctx.Done():
			return

		case <-time.After(-diff): // cant refresh
			return

		case <-time.After(refresh):
			status, err := scriptRefresh.Run(l.ctx, l.redis, []string{l.key}, l.id, l.ttl.Milliseconds()).Int()
			if err != nil {
				if errors.Is(err, context.Canceled) {
					return
				}

				refresh = refreshTimeout
				l.log("[ERROR] refresh key %q %s: %v", l.key, l.id, err)
				continue
			}

			left = l.leftTTL()
			refresh = l.updateTTL()
			if status == 0 {
				l.log("[ERROR] refresh key %q %s already expired", l.key, l.id)
				return
			}
		}
	}
}

func (l *Lock) leftTTL() time.Time {
	return time.Now().Add(l.ttl)
}

func (l *Lock) updateTTL() time.Duration {
	return l.ttl / 2
}

为什么需要为什么l.ttl / 2

这是为了实现锁的自动续约。通过定期刷新锁的过期时间,可以确保锁在使用过程中不会过期而被意外释放。

这种做法可以在以下情况下带来一些好处:

Options

package redismutex

import (
	"context"
	"log"
	"os"
	"sync"
	"time"
)

const (
	lenBytesID     = 16
	refreshTimeout = time.Millisecond * 500
	defaultKeyTTL  = time.Second * 4
)

var (
	globalMx  sync.RWMutex
	globalLog = func() LogFunc {
		l := log.New(os.Stderr, "redismutex: ", log.LstdFlags)
		return func(format string, v ...any) {
			l.Printf(format, v...)
		}
	}()
)

// LogFunc type is an adapter to allow the use of ordinary functions as LogFunc.
type LogFunc func(format string, v ...any)

// NopLog logger does nothing
var NopLog = LogFunc(func(string, ...any) {})

// SetLog sets the logger.
func SetLog(l LogFunc) {
	globalMx.Lock()
	defer globalMx.Unlock()

	if l != nil {
		globalLog = l
	}
}

// MutexOption is the option for the mutex.
type MutexOption func(*mutexOptions)

type mutexOptions struct {
	name       string
	ttl        time.Duration
	lockIntent bool
	log        LogFunc
}

// WithTTL sets the TTL of the mutex.
func WithTTL(ttl time.Duration) MutexOption {
	return func(o *mutexOptions) {
		if ttl >= time.Second*2 {
			o.ttl = ttl
		}
	}
}

// WithLockIntent sets the lock intent.
func WithLockIntent() MutexOption {
	return func(o *mutexOptions) {
		o.lockIntent = true
	}
}

// LockOption is the option for the lock.
type LockOption func(*lockOptions)

type lockOptions struct {
	ctx              context.Context
	key              string
	lockIntentKey    string
	enableLockIntent int
	ttl              time.Duration
	log              LogFunc
}

func newLockOptions(m mutexOptions, opt ...LockOption) lockOptions {
	opts := lockOptions{
		ctx:              context.Background(),
		key:              m.name,
		enableLockIntent: boolToInt(m.lockIntent),
		ttl:              m.ttl,
		log:              m.log,
	}

	for _, o := range opt {
		o(&opts)
	}

	opts.lockIntentKey = lockIntentKey(opts.key)
	return opts
}

// WithKey sets the key of the lock.
func WithKey(key string) LockOption {
	return func(o *lockOptions) {
		if key != "" {
			o.key += ":" + key
		}
	}
}

// WithContext sets the context of the lock.
func WithContext(ctx context.Context) LockOption {
	return func(o *lockOptions) {
		if ctx != nil {
			o.ctx = ctx
		}
	}
}

func boolToInt(b bool) int {
	if b {
		return 1
	}
	return 0
}

func lockIntentKey(key string) string {
	return key + ":lock-intent"
}

可以通过设置选项来控制互斥锁的行为和属性,如生存时间、锁意图、上下文等。还提供了一些实用函数和类型,用于管理互斥锁和生成选项

redismutex

// Package redismutex provides a distributed rw mutex.
package redismutex

import (
	"context"
	"crypto/rand"
	"embed"
	"encoding/hex"
	"errors"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
)

var ErrLock = errors.New("redismutex: lock not obtained")

var (
	//go:embed lua
	lua embed.FS

	scriptRLock   *redis.Script
	scriptLock    *redis.Script
	scriptRefresh *redis.Script
	scriptUnlock  *redis.Script
)

func init() {
	scriptRLock = redis.NewScript(mustReadFile("rlock.lua"))
	scriptLock = redis.NewScript(mustReadFile("lock.lua"))
	scriptRefresh = redis.NewScript(mustReadFile("refresh.lua"))
	scriptUnlock = redis.NewScript(mustReadFile("unlock.lua"))
}

// A RWMutex is a distributed mutual exclusion lock.
type RWMutex struct {
	redis redis.Scripter
	opts  mutexOptions

	id struct {
		sync.Mutex
		buf []byte
	}
}

// NewMutex creates a new distributed mutex.
func NewMutex(rc redis.Scripter, name string, opt ...MutexOption) *RWMutex {
	globalMx.RLock()
	defer globalMx.RUnlock()

	opts := mutexOptions{
		name: name,
		ttl:  defaultKeyTTL,
		log:  globalLog,
	}

	for _, o := range opt {
		o(&opts)
	}

	rw := &RWMutex{
		redis: rc,
		opts:  opts,
	}
	rw.id.buf = make([]byte, lenBytesID)
	return rw
}

// TryRLock tries to lock for reading and reports whether it succeeded.
func (m *RWMutex) TryRLock(opt ...LockOption) (*Lock, bool) {
	opts := newLockOptions(m.opts, opt...)
	ctx, _, err := m.rlock(opts)
	if err != nil {
		if !errors.Is(err, ErrLock) {
			m.opts.log("[ERROR] try-read-lock key %q: %v", opts.key, err)
		}
		return nil, false
	}
	return ctx, true
}

// RLock locks for reading.
func (m *RWMutex) RLock(opt ...LockOption) (*Lock, bool) {
	opts := newLockOptions(m.opts, opt...)
	ctx, ttl, err := m.rlock(opts)
	if err == nil {
		return ctx, true
	}

	if !errors.Is(err, ErrLock) {
		m.opts.log("[ERROR] read-lock key %q: %v", opts.key, err)
		return nil, false
	}

	for {
		select {
		case <-opts.ctx.Done():
			m.opts.log("[ERROR] read-lock key %q: %v", opts.key, opts.ctx.Err())
			return nil, false

		case <-time.After(ttl):
			ctx, ttl, err = m.rlock(opts)
			if err == nil {
				return ctx, true
			}

			if !errors.Is(err, ErrLock) {
				m.opts.log("[ERROR] read-lock key %q: %v", opts.key, err)
				return nil, false
			}
			continue
		}
	}
}

// TryLock tries to lock for writing and reports whether it succeeded.
func (m *RWMutex) TryLock(opt ...LockOption) (*Lock, bool) {
	opts := newLockOptions(m.opts, opt...)
	opts.enableLockIntent = 0 // force disable lock intent

	ctx, _, err := m.lock(opts)
	if err != nil {
		if !errors.Is(err, ErrLock) {
			m.opts.log("[ERROR] try-lock key %q: %v", opts.key, err)
		}
		return nil, false
	}
	return ctx, true
}

// Lock locks for writing.
func (m *RWMutex) Lock(opt ...LockOption) (*Lock, bool) {
	opts := newLockOptions(m.opts, opt...)
	ctx, ttl, err := m.lock(opts)

	if err == nil {
		return ctx, true
	}

	if !errors.Is(err, ErrLock) {
		m.opts.log("[ERROR] lock key %q: %v", opts.key, err)
		return nil, false
	}

	for {
		select {
		case <-opts.ctx.Done():
			m.opts.log("[ERROR] lock key %q: %v", opts.key, opts.ctx.Err())
			return nil, false

		case <-time.After(ttl):
			ctx, ttl, err = m.lock(opts)
			if err == nil {
				return ctx, true
			}

			if !errors.Is(err, ErrLock) {
				m.opts.log("[ERROR] lock key %q: %v", opts.key, err)
				return nil, false
			}
			continue
		}
	}
}

func (m *RWMutex) lock(opts lockOptions) (*Lock, time.Duration, error) {
	id, err := m.randomID()
	if err != nil {
		return nil, 0, err
	}

	pTTL, err := scriptLock.Run(opts.ctx, m.redis, []string{opts.key, opts.lockIntentKey}, id, opts.ttl.Milliseconds(), opts.enableLockIntent).Result()
	leftTTL := time.Now().Add(opts.ttl)
	if err == nil {
		return nil, time.Duration(pTTL.(int64)) * time.Millisecond, ErrLock
	}

	if err != redis.Nil {
		return nil, 0, err
	}

	ctx, cancel := context.WithCancel(opts.ctx)
	lock := &Lock{
		redis:  m.redis,
		id:     id,
		ttl:    opts.ttl,
		key:    opts.key,
		log:    opts.log,
		ctx:    ctx,
		cancel: cancel,
	}
	go lock.refreshTTL(leftTTL)
	return lock, 0, nil
}

func (m *RWMutex) rlock(opts lockOptions) (*Lock, time.Duration, error) {
	id, err := m.randomID()
	if err != nil {
		return nil, 0, err
	}

	pTTL, err := scriptRLock.Run(opts.ctx, m.redis, []string{opts.key, opts.lockIntentKey}, id, opts.ttl.Milliseconds()).Result()
	leftTTL := time.Now().Add(opts.ttl)
	if err == nil {
		return nil, time.Duration(pTTL.(int64)) * time.Millisecond, ErrLock
	}

	if err != redis.Nil {
		return nil, 0, err
	}

	ctx, cancel := context.WithCancel(opts.ctx)
	lock := &Lock{
		redis:  m.redis,
		id:     id,
		ttl:    opts.ttl,
		key:    opts.key,
		log:    opts.log,
		ctx:    ctx,
		cancel: cancel,
	}
	go lock.refreshTTL(leftTTL)
	return lock, 0, nil
}

// randomID generates a random hex string with 16 bytes.
func (m *RWMutex) randomID() (string, error) {
	m.id.Lock()
	defer m.id.Unlock()

	_, err := rand.Read(m.id.buf)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(m.id.buf), nil
}

func mustReadFile(filename string) string {
	b, err := lua.ReadFile("lua/" + filename)
	if err != nil {
		panic(err)
	}
	return string(b)
}

测试用例

package redismutex

import (
	"context"
	"errors"
	"log"
	"strings"
	"testing"
	"time"

	"github.com/redis/go-redis/v9"
)

func init() {
	SetLog(func(format string, a ...any) {
		if strings.HasPrefix(format, "[ERROR]") {
			log.Fatalf(format, a...)
		}
	})
}

func TestMutex(t *testing.T) {
	t.Parallel()

	const lockKey = "mutex"
	rc := redis.NewClient(redisOpts())
	prep(t, rc, lockKey)

	mx := NewMutex(rc, lockKey)
	lock, ok := mx.Lock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()

	assertTTL(t, rc, lockKey, defaultKeyTTL)

	// try again
	_, ok = mx.TryLock()
	if exp, got := false, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	_, ok = mx.TryRLock()
	if exp, got := false, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	// manually unlock
	lock.Unlock()

	// lock again
	lock, ok = mx.Lock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()
}

func TestRWMutex(t *testing.T) {
	t.Parallel()

	const lockKey = "rw_mutex"
	rc := redis.NewClient(redisOpts())
	prep(t, rc, lockKey)

	mx := NewMutex(rc, lockKey)
	lock, ok := mx.RLock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()

	assertTTL(t, rc, lockKey, defaultKeyTTL)

	// try again
	_, ok = mx.TryLock()
	if exp, got := false, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	// try rlock
	rlock, ok := mx.TryRLock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	rlock.Unlock()

	// manually unlock
	lock.Unlock()

	// lock again
	lock, ok = mx.Lock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()
}

func TestRWMutex_LockIntent(t *testing.T) {
	t.Parallel()

	const lockKey = "lock_intent_mutex"
	rc := redis.NewClient(redisOpts())
	prep(t, rc, lockKey)

	mx := NewMutex(rc, lockKey, WithLockIntent())
	lock, ok := mx.RLock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()

	// mark lock intent
	_, _, err := mx.lock(newLockOptions(mx.opts))
	if exp, got := ErrLock, err; !errors.Is(got, exp) {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	// try rlock
	_, ok = mx.TryRLock()
	if exp, got := false, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	// manually unlock
	lock.Unlock()

	// lock write
	lock, ok = mx.Lock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	lock.Unlock() // remove lock intent

	// lock again
	lock, ok = mx.RLock()
	if exp, got := true, ok; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
	defer lock.Unlock()
}

func TestRWMutex_ID(t *testing.T) {
	t.Parallel()

	rw := &RWMutex{}
	rw.id.buf = make([]byte, lenBytesID)
	id, _ := rw.randomID()
	if exp, got := 32, len(id); exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}
}

func prep(t *testing.T, rc *redis.Client, key string) {
	t.Cleanup(func() {
		for _, v := range []string{key, lockIntentKey(key)} {
			if err := rc.Del(context.Background(), v).Err(); err != nil {
				t.Fatal(err)
			}
		}

		if err := rc.Close(); err != nil {
			t.Fatal(err)
		}
	})
}

func assertTTL(t *testing.T, rc *redis.Client, key string, exp time.Duration) {
	t.Helper()

	got, err := rc.TTL(context.Background(), key).Result()
	if exp, got := (any)(nil), err; exp != got {
		t.Fatalf("exp %v, got %v", exp, got)
	}

	delta := got - exp
	if delta < 0 {
		delta = 1 - delta
	}

	if delta > time.Second {
		t.Fatalf("exp ~%v, got %v", exp, got)
	}
}

func redisOpts() *redis.Options {
	return &redis.Options{
		Network: "tcp",
		Addr:    "0.0.0.0:6379",
		DB:      9,
	}
}

以上就是利用Redis lua实现高效读写锁的代码实例的详细内容,更多关于Redis lua读写锁的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文