编辑
2023-09-21
后端
00
请注意,本文编写于 597 天前,最后修改于 431 天前,其中某些信息可能已经过时。

我们可以很简单得用github.com/gorilla/websocket来实现一个websocket,先用GET请求进行握手,然后直接升级协议

go
var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { // 允许所有的跨域请求 return true }, }

我们先来获取一个用于升级协议的对象。

go
func echoHandler(w http.ResponseWriter, r *http.Request) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println("WebSocket upgrade error:", err) return } defer conn.Close() for { // 读取客户端发送的消息 _, message, err := conn.ReadMessage() if err != nil { log.Println("WebSocket read error:", err) break } // 打印接收到的消息 log.Printf("Received message from client: %s", message) // 发送消息给客户端 err = conn.WriteMessage(websocket.TextMessage, message) if err != nil { log.Println("WebSocket write error:", err) break } } }

定义一个简单的处理函数

go
http.HandleFunc("/echo", echoHandler)//默认处理GET请求 log.Fatal(http.ListenAndServe(":8080", nil))

开始监听这个服务

接下来实现一个简单的客户端,建立起一个简单的websocket链接

go
package main import ( "fmt" "log" "net/url" "os" "os/signal" "time" "github.com/gorilla/websocket" ) func main() { interrupt := make(chan os.Signal, 1) signal.Notify(interrupt, os.Interrupt) u := url.URL{Scheme: "ws", Host: "localhost:3000", Path: "/ws"} log.Printf("connecting to %s", u.String()) c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { log.Fatal("WebSocket dial error:", err) } defer c.Close() done := make(chan struct{}) go func() { defer close(done) for { _, message, err := c.ReadMessage() if err != nil { log.Println("WebSocket read error:", err) return } log.Printf("Received message from server: %s", message) } }() ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { case <-done: return case t := <-ticker.C: message := []byte(fmt.Sprintf("Message sent at %v", t)) err := c.WriteMessage(websocket.TextMessage, message) if err != nil { log.Println("WebSocket write error:", err) return } log.Printf("Sent message to server: %s", message) case <-interrupt: log.Println("Interrupt signal received") err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { log.Println("WebSocket close error:", err) return } select { case <-done: case <-time.After(time.Second): } return } } }

利用这个定时器来触发消息发送,运行这两个程序,可以实现双向通信,可以看到他们不断发送信息,你可以着手实现自己的业务逻辑。

gin框架中使用也很简单就是,把request和response从Context中取出就行。

go
func SendWs(c *gin.Context) { ws, err := websocket.Upgrade(c.Writer, c.Request, nil, 1024, 1024) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "协议升级失败!"}) return } for { // 读取客户端发送的消息 _, msg, err := ws.ReadMessage() if err != nil { break } // 处理接收到的消息 // 这里可以根据业务逻辑进行处理 fmt.Println("收到:", string(msg)) rep := "收到啦奥特曼!" // 回复消息给客户端 err = ws.WriteMessage(websocket.TextMessage, []byte(rep)) if err != nil { // 处理回复消息失败的情况 break } } // 关闭WebSocket连接 ws.Close() }

可以看到,我们利用了websocker.Upgrader来升级协议,不是使用升级协议的那个对象,其实底层也是一样

go
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) { u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize} u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) { // don't return errors to maintain backwards compatibility } u.CheckOrigin = func(r *http.Request) bool { // allow all connections by default return true } return u.Upgrade(w, r, responseHeader) }

只是设置了缓冲区大小,原来是直接创建对象的时候设置。

我们看一下这个Upgrade方法

go
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { // 定义错误消息 const badHandshake = "websocket: the client is not using the websocket protocol: " // 检查 'Connection' 头是否包含 'upgrade' 标记 if !tokenListContainsValue(r.Header, "Connection", "upgrade") { return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") } // 检查 'Upgrade' 头是否包含 'websocket' 标记 if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") } // 检查请求方法是否为 GET if r.Method != http.MethodGet { return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") } // 检查 'Sec-Websocket-Version' 头是否为 13(WebSocket 版本) if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") } // 检查是否存在 'Sec-Websocket-Extensions' 头,不支持应用程序特定的扩展 if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") } // 验证请求的来源是否被允许,可以通过 Upgrader.CheckOrigin 配置 checkOrigin := u.CheckOrigin if checkOrigin == nil { checkOrigin = checkSameOrigin } if !checkOrigin(r) { return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") } // 从请求头中获取 'Sec-Websocket-Key',该键用于生成响应头的 'Sec-WebSocket-Accept' challengeKey := r.Header.Get("Sec-Websocket-Key") if challengeKey == "" { return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank") } // 选择子协议(Subprotocol)用于通信,如果有多个可选的子协议 subprotocol := u.selectSubprotocol(r, responseHeader) // 协商 PMCE(Per-Message Compression Extension),用于数据传输时的数据压缩 var compress bool if u.EnableCompression { for _, ext := range parseExtensions(r.Header) { if ext[""] != "permessage-deflate" { continue } compress = true break } } // 使用 http.Hijacker 接口获取底层连接 h, ok := w.(http.Hijacker) if !ok { return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") } var brw *bufio.ReadWriter netConn, brw, err := h.Hijack() if err != nil { return u.returnError(w, r, http.StatusInternalServerError, err.Error()) } // 如果有数据在握手完成前发送,关闭连接 if brw.Reader.Buffered() > 0 { netConn.Close() return nil, errors.New("websocket: client sent data before handshake is complete") } var br *bufio.Reader if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { // 重用 hijacked 缓冲读取器作为连接读取器 br = brw.Reader } buf := bufioWriterBuffer(netConn, brw.Writer) var writeBuf []byte if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { // 重用 hijacked 写入缓冲作为连接缓冲 writeBuf = buf } // 创建新的 Conn 对象表示 WebSocket 连接 c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) c.subprotocol = subprotocol // 如果启用压缩,设置压缩和解压缩方法 if compress { c.newCompressionWriter = compressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover } // 使用 hijacked 缓冲区和连接缓冲区作为头部 p := buf if len(c.writeBuf) > len(p) { p = c.writeBuf } p = p[:0] // 构建 WebSocket 握手响应头部 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, computeAcceptKey(challengeKey)...) p = append(p, "\r\n"...) if c.subprotocol != "" { p = append(p, "Sec-WebSocket-Protocol: "...) p = append(p, c.subprotocol...) p = append(p, "\r\n"...) } if compress { p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) } for k, vs := range responseHeader { if k == "Sec-Websocket-Protocol" { continue } for _, v := range vs { p = append(p, k...) p = append(p, ": "...) for i := 0; i < len(v); i++ { b := v[i] if b <= 31 { // 防止响应分割 b = ' ' } p = append(p, b) } p = append(p, "\r\n"...) } } p = append(p, "\r\n"...) // 清除 HTTP 服务器设置的截止日期 netConn.SetDeadline(time.Time{}) // 如果设置了握手超时,设置写入截止日期 if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout)) } // 发送 WebSocket 握手响应头部到客户端 if _, err = netConn.Write(p); err != nil { netConn.Close() return nil, err } // 如果设置了握手超时,取消写入截止日期 if u.HandshakeTimeout > 0 { netConn.SetWriteDeadline(time.Time{}) } return c, nil }

主要做了一下事情

  • 检查请求头中的Connection和Upgrade字段是否包含正确的值。 检查请求方法是否为GET。

  • 检查请求头中的Sec-Websocket-Version字段是否为支持的版本。

  • 检查响应头中是否包含非法的Sec-Websocket-Extensions字段。

  • 检查请求来源是否通过CheckOrigin函数的验证。

  • 获取握手请求中的Sec-Websocket-Key字段,用于生成响应头中的Sec-WebSocket-Accept字段。

  • 选择子协议。

  • 根据是否启用压缩,选择合适的压缩算法。

  • 使用http.Hijacker接口将http.ResponseWriter转换为底层的网络连接。

  • 创建*Conn对象,并设置相关属性,如缓冲区大小、子协议等。

  • 构造握手响应头,并发送给客户端。

  • 返回升级完成的*Conn对象。

最近拿着获取到的的对象进行通信。

本文作者:yowayimono

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!