Golang

关注公众号 jb51net

关闭
首页 > 脚本专栏 > Golang > go websocket中间件

一文带你使用golang手撸一个websocket中间件

作者:47笔记

这篇文章主要为大家详细介绍了如何使用golang手撸一个websocket中间件,文中的示例代码讲解详细,具有一定的借鉴价值,感兴趣的小伙伴可以参考一下

序言

【1】:少年,你的项目要不要用到websocket呀?

(你说通讯用来干嘛?在线聊天、消息推送、扫码登录、物联网设备管理、游戏、弹幕等等这些需要立即得到反馈但又不想给服务器增加太多负担的业务场景都需要用websocket)

【2】:为什么要使用websocket?

我们先来模拟一个简单的扫码登录网页的场景,网页端在生成的二维码后要怎么知道用户有没有用手机扫描这个二维码呢?在传统的项目中最长的方式就是不停的去请求后端接口问 “用户扫了没?用户扫了没?用户扫了没?(一直往复)”,直到用户扫完或者关闭了网页。

这种方式就是最常见的 长轮训(我最开始学写代码的时候也是用这种方式),这种方式是最简单的,但是也与之相对应的问题也很明显,占用太多后台资源 and 感官延迟,那有没有一种别的方案能不占资源又快的方式呢?

在类似这种需求下 websocket 诞生了,今天这里我们不谈那些枯燥的理论知识,只玩实操!

<!DOCTYPE html>
<html>
<head>
</head>
<body>
<script>

  var ws = new WebSocket("ws://127.0.0.1:8282");

  ws.onmessage = function(event) {
    
      console.log(event.data)
  };

  ws.onclose = function(event) {

      console.log('ws连接已关闭')
  }


</script>
</body>
</html>

上面是一个最简单的 websocket 连接代码

嗯~,到这里,你已经能正常的连接到服务器上并等待 服务器主动 给你发送消息了。

好的,那现在客户端准备好了,服务端呢,别急,现在我们从 0 开始一起手撸一个websocket服务端

服务端设计

开始搭建服务端前,我们必须先思考一下它架构方式。

【1】:为什么是用golang来开发服务端,用其他语言不行吗?

当然是可以的!在文章的标题中我有提到这个用golang来开发,是因为其他语言都有现成可以用的socket中间件,但是golang似乎还没有,那我们就来自己撸一个吧!

【2】:服务端定位

开始动手写代码前应该提前思考这么一个问题,这个websocket服务应该以一种怎样的方式存在于项目中呢?

(嵌套组件 OR 独立中间件)

我的想法是中间件,出于以下原因考虑:

【3】:架构设计

上面我有提到,这个websocket服务端的两个主要功能是 管理连接和推送消息

好的,那首先围绕第一个问题,如何管理连接?当服务端出现N多个连接时要怎么知道谁是谁,消息应该推给谁?

写过php的同学应该知道workerman这个中间件,在workerman中有三个非常重要的概念,clientusergroup,其实就是 分类管理,下面我分别解释一下

在看下面的内容前,大家一定要先消化了解这三个概念

还有另外一个问题,怎么让websocket服务不侵入业务代码

这里我大概画了一张草图,三者之间的关系可以这样理解

掉头发时间

在博客我只展示一点点代码哈,其他的都已经完全开源到github了,各位看官需要的话自取哈

我们要使用golang来实现socket服务,自然离不开 github.com/gorilla/websocket 这个核心库啦!

这里不得不说,golang的生态还是挺完善的。

gorilla/websocket帮我们解决了socket的连接和推送问题,剩下连接关系管理服务接口就是我们要关注的重点了。

【1】:连接关系管理

先来给大家上两段代码

server.go

package websocket

import (
	"fmt"
	"sync"
	"time"

	"github.com/golang-module/carbon"
	"github.com/gorilla/websocket"
)

type WebSocketClientBase struct {
	ID            string
	Conn          *websocket.Conn
	LastHeartbeat int64
	BindUid       string
	JoinGroup     []string
}

type WebSocketUserBase struct {
	Uid      string
	ClientID []string
}

type WebSocketGroupBase struct {
	ClientID []string
}

var GatewayClients, GatewayUser, GatewayGroup sync.Map

/**
 * @description: 客户端心跳检测,超时即断开连接(主要是为了降低服务端承载压力)
 * @param {string} clientID
 * @return {*}
 */
func clientHeartbeatCheck(clientID string) {

	for {

		time.Sleep(5 * time.Second)

		clientInterface, exists := GatewayClients.Load(clientID)

		if !exists {

			break
		}

		client, _ := clientInterface.(*WebSocketClientBase)

		if (carbon.Now().Timestamp() - client.LastHeartbeat) > int64(HeartbeatTime) {

			fmt.Println("Client", clientID, "heartbeat timeout")

			client.Conn.Close()
			GatewayClients.Delete(clientID)
			break
		}
	}
}

/**
 * @description: 客户端断线时自动踢出Uid绑定列表
 * @param {string} clientID
 * @param {string} uid
 * @return {*}
 */
func clientUnBindUid(clientID string, uid string) {

	value, ok := GatewayUser.Load(uid)

	if ok {

		users := value.(*WebSocketUserBase)

		for k, v := range users.ClientID {

			if v == clientID {

				users.ClientID = append(users.ClientID[:k], users.ClientID[k+1:]...)
			}
		}

		if len(users.ClientID) == 0 {

			GatewayUser.Delete(uid)
		}

	}
}

/**
 * @description: 客户端断线时自动踢出已加入的群组
 * @param {string} clientID
 * @return {*}
 */
func clientLeaveGroup(clientID string) {
	// 使用 Load 方法获取值
	value, ok := GatewayClients.Load(clientID)
	if !ok {
		// 如果没有找到对应的值,处理相应的逻辑
		return
	}

	client := value.(*WebSocketClientBase)

	// 遍历 JoinGroup
	for _, v := range client.JoinGroup {
		// 使用 Load 方法获取值
		groupValue, groupOK := GatewayGroup.Load(v)
		if !groupOK {
			// 如果没有找到对应的值,处理相应的逻辑
			continue
		}

		group := groupValue.(*WebSocketGroupBase)

		// 在群组中找到对应的 clientID,并删除
		for j, id := range group.ClientID {
			if id == clientID {
				copy(group.ClientID[j:], group.ClientID[j+1:])
				group.ClientID = group.ClientID[:len(group.ClientID)-1]

				// 如果群组中没有成员了,删除群组
				if len(group.ClientID) == 0 {
					GatewayGroup.Delete(v)
				}

				break
			}
		}
	}
}

connect.go

package websocket

import (
	"fmt"
	"gateway-websocket/config"
	"net/http"
	"runtime/debug"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-module/carbon"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

var (
	upGrader = websocket.Upgrader{
		// 设置消息接收缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在读取客户端发送的大型消息时遇到问题
		ReadBufferSize: config.GatewayConfig["ReadBufferSize"].(int),
		// 设置消息发送缓冲区大小(byte),如果这个值设置得太小,可能会导致服务端在发送大型消息时遇到问题
		WriteBufferSize: config.GatewayConfig["WriteBufferSize"].(int),
		// 消息包启用压缩
		EnableCompression: config.GatewayConfig["MessageCompression"].(bool),
		// ws握手超时时间
		HandshakeTimeout: time.Duration(config.GatewayConfig["WebsocketHandshakeTimeout"].(int)) * time.Second,
		// ws握手过程中允许跨域
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}

	// 设置心跳检测间隔时长(秒)
	HeartbeatTime = config.GatewayConfig["HeartbeatTimeout"].(int)
)

/**
 * @description: 初始化客户端连接
 * @param {*websocket.Conn} conn
 * @return {*}
 */
func handleClientInit(conn *websocket.Conn) string {

	clientID := uuid.New().String()

	client := &WebSocketClientBase{
		ID:            clientID,
		Conn:          conn,
		LastHeartbeat: carbon.Now().Timestamp(),
	}

	// 使用 Store 方法存储值
	GatewayClients.Store(clientID, client)

	if err := conn.WriteMessage(config.GatewayConfig["MessageFormat"].(int), []byte(clientID)); err != nil {

		handleClientDisconnect(clientID)
		return ""
	}

	return clientID
}

/**
 * @description: 主动关闭客户端连接
 * @param {string} clientID
 * @return {*}
 */
func handleClientDisconnect(clientID string) {

	// 使用 Load 和 Delete 方法,不需要额外的锁定操作
	v, ok := GatewayClients.Load(clientID)
	if ok {

		client := v.(*WebSocketClientBase)

		if client.BindUid != "" {
			clientUnBindUid(clientID, client.BindUid)
		}

		if len(client.JoinGroup) > 0 {
			clientLeaveGroup(clientID)
		}

		GatewayClients.Delete(clientID)
	}
}

/**
 * @description: 向客户端回复心跳消息
 * @param {*websocket.Conn} conn
 * @param {string} clientID
 * @param {int} messageType
 * @param {[]byte} message
 * @return {*}
 */
func handleClientMessage(conn *websocket.Conn, clientID string, messageType int, message []byte) {

	// 使用 Load 方法获取值
	v, ok := GatewayClients.Load(clientID)
	if !ok {
		// 如果没有找到对应的值,处理相应的逻辑
		handleClientDisconnect(clientID)
		return
	}

	client := v.(*WebSocketClientBase)

	if messageType == config.GatewayConfig["MessageFormat"].(int) && string(message) == "ping" {

		if err := conn.WriteMessage(config.GatewayConfig["MessageFormat"].(int), []byte("pong")); err != nil {

			handleClientDisconnect(clientID)
			return
		}

		GatewayClients.Store(clientID, &WebSocketClientBase{
			ID:            clientID,
			Conn:          conn,
			LastHeartbeat: carbon.Now().Timestamp(),
			BindUid:       client.BindUid,
			JoinGroup:     client.JoinGroup,
		})
	}
}

func WsServer(c *gin.Context) {

	defer func() {
		if err := recover(); err != nil {
			fmt.Printf("WsServer panic: %v\n", err)
			debug.PrintStack()
		}
	}()

	// 将 HTTP 连接升级为 WebSocket 连接
	conn, err := upGrader.Upgrade(c.Writer, c.Request, nil)

	if err != nil {
		return
	}

	defer conn.Close()

	// 客户端唯一身份标识
	clientID := handleClientInit(conn)

	// 发送客户端唯一标识 ID
	if clientID == "" {
		return
	}

	go clientHeartbeatCheck(clientID)

	for {

		// 读取客户端发送过来的消息
		messageType, message, err := conn.ReadMessage()

		// 当收到err时则标识客户端连接出现异常,如断线
		if err != nil {

			handleClientDisconnect(clientID)

		} else {

			handleClientMessage(conn, clientID, messageType, message)
		}
	}

}

在上面的代码中,我创建了一个websocket的连接服务和使用了3个sync.Map来分别存放管理不同的客户端连接

(在做这种存在高并发场景的业务时不要使用Map而是用sync.Map,因为go的Map是非线程安全的,在并发时会造成资源竞争从而导致你的程序宕掉,这点一定要注意!!!)

Stop,文章好像被拉的太长了(⊙o⊙)…,那就只展示一点点吧,其他的代码和php操作Demo都完全开源到github啦,大家自取哈。

测试时间

代码写完,先把程序run起来

然后压测安排上

大家可以在自己电脑上试试看,我这个Jmeter不知道什么原因,线程数超过1000后就运行很慢了

(单纯是Jmeter慢,不是go哈,也可能是我电脑的问题)

以上就是一文带你使用golang手撸一个websocket中间件的详细内容,更多关于go websocket中间件的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文