我们可以很简单得用github.com/gorilla/websocket
来实现一个websocket,先用GET请求进行握手,然后直接升级协议
govar upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
// 允许所有的跨域请求
return true
},
}
我们先来获取一个用于升级协议的对象。
gofunc echoHandler(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println("WebSocket upgrade error:", err)
return
}
defer conn.Close()
for {
// 读取客户端发送的消息
_, message, err := conn.ReadMessage()
if err != nil {
log.Println("WebSocket read error:", err)
break
}
// 打印接收到的消息
log.Printf("Received message from client: %s", message)
// 发送消息给客户端
err = conn.WriteMessage(websocket.TextMessage, message)
if err != nil {
log.Println("WebSocket write error:", err)
break
}
}
}
定义一个简单的处理函数
go http.HandleFunc("/echo", echoHandler)//默认处理GET请求
log.Fatal(http.ListenAndServe(":8080", nil))
开始监听这个服务
接下来实现一个简单的客户端,建立起一个简单的websocket链接
gopackage main
import (
"fmt"
"log"
"net/url"
"os"
"os/signal"
"time"
"github.com/gorilla/websocket"
)
func main() {
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
u := url.URL{Scheme: "ws", Host: "localhost:3000", Path: "/ws"}
log.Printf("connecting to %s", u.String())
c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
log.Fatal("WebSocket dial error:", err)
}
defer c.Close()
done := make(chan struct{})
go func() {
defer close(done)
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Println("WebSocket read error:", err)
return
}
log.Printf("Received message from server: %s", message)
}
}()
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-done:
return
case t := <-ticker.C:
message := []byte(fmt.Sprintf("Message sent at %v", t))
err := c.WriteMessage(websocket.TextMessage, message)
if err != nil {
log.Println("WebSocket write error:", err)
return
}
log.Printf("Sent message to server: %s", message)
case <-interrupt:
log.Println("Interrupt signal received")
err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
if err != nil {
log.Println("WebSocket close error:", err)
return
}
select {
case <-done:
case <-time.After(time.Second):
}
return
}
}
}
利用这个定时器来触发消息发送,运行这两个程序,可以实现双向通信,可以看到他们不断发送信息,你可以着手实现自己的业务逻辑。
gin框架中使用也很简单就是,把request和response从Context中取出就行。
gofunc SendWs(c *gin.Context) {
ws, err := websocket.Upgrade(c.Writer, c.Request, nil, 1024, 1024)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "协议升级失败!"})
return
}
for {
// 读取客户端发送的消息
_, msg, err := ws.ReadMessage()
if err != nil {
break
}
// 处理接收到的消息
// 这里可以根据业务逻辑进行处理
fmt.Println("收到:", string(msg))
rep := "收到啦奥特曼!"
// 回复消息给客户端
err = ws.WriteMessage(websocket.TextMessage, []byte(rep))
if err != nil {
// 处理回复消息失败的情况
break
}
}
// 关闭WebSocket连接
ws.Close()
}
可以看到,我们利用了websocker.Upgrader
来升级协议,不是使用升级协议的那个对象,其实底层也是一样
gofunc Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
只是设置了缓冲区大小,原来是直接创建对象的时候设置。
我们看一下这个Upgrade
方法
gofunc (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
// 定义错误消息
const badHandshake = "websocket: the client is not using the websocket protocol: "
// 检查 'Connection' 头是否包含 'upgrade' 标记
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
}
// 检查 'Upgrade' 头是否包含 'websocket' 标记
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
// 检查请求方法是否为 GET
if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}
// 检查 'Sec-Websocket-Version' 头是否为 13(WebSocket 版本)
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
}
// 检查是否存在 'Sec-Websocket-Extensions' 头,不支持应用程序特定的扩展
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
// 验证请求的来源是否被允许,可以通过 Upgrader.CheckOrigin 配置
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
}
// 从请求头中获取 'Sec-Websocket-Key',该键用于生成响应头的 'Sec-WebSocket-Accept'
challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header is missing or blank")
}
// 选择子协议(Subprotocol)用于通信,如果有多个可选的子协议
subprotocol := u.selectSubprotocol(r, responseHeader)
// 协商 PMCE(Per-Message Compression Extension),用于数据传输时的数据压缩
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
// 使用 http.Hijacker 接口获取底层连接
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
// 如果有数据在握手完成前发送,关闭连接
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// 重用 hijacked 缓冲读取器作为连接读取器
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// 重用 hijacked 写入缓冲作为连接缓冲
writeBuf = buf
}
// 创建新的 Conn 对象表示 WebSocket 连接
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol
// 如果启用压缩,设置压缩和解压缩方法
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
// 使用 hijacked 缓冲区和连接缓冲区作为头部
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
// 构建 WebSocket 握手响应头部
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// 防止响应分割
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
// 清除 HTTP 服务器设置的截止日期
netConn.SetDeadline(time.Time{})
// 如果设置了握手超时,设置写入截止日期
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
// 发送 WebSocket 握手响应头部到客户端
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
// 如果设置了握手超时,取消写入截止日期
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
主要做了一下事情
检查请求头中的Connection和Upgrade字段是否包含正确的值。 检查请求方法是否为GET。
检查请求头中的Sec-Websocket-Version字段是否为支持的版本。
检查响应头中是否包含非法的Sec-Websocket-Extensions字段。
检查请求来源是否通过CheckOrigin函数的验证。
获取握手请求中的Sec-Websocket-Key字段,用于生成响应头中的Sec-WebSocket-Accept字段。
选择子协议。
根据是否启用压缩,选择合适的压缩算法。
使用http.Hijacker接口将http.ResponseWriter转换为底层的网络连接。
创建*Conn对象,并设置相关属性,如缓冲区大小、子协议等。
构造握手响应头,并发送给客户端。
返回升级完成的*Conn对象。
最近拿着获取到的的对象进行通信。
本文作者:yowayimono
本文链接:
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!