From a78bf12bf0747682281af8bd038c7bccf5f3a5e3 Mon Sep 17 00:00:00 2001 From: ChillerDragon Date: Thu, 20 Jun 2024 12:29:26 +0800 Subject: [PATCH] Make OnPacket return a result struct --- protocol7/connection.go | 67 +++++++++++++++++++++++------------------ teeworlds.go | 17 +++++------ 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/protocol7/connection.go b/protocol7/connection.go index 009557f..24a7bea 100644 --- a/protocol7/connection.go +++ b/protocol7/connection.go @@ -95,17 +95,17 @@ func byteSliceToString(s []byte) string { return string(s) } -func (connection *Connection) OnSystemMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, response *Packet) { +func (connection *Connection) OnSystemMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, result *PacketResult) { if msg == network7.MsgSysMapChange { fmt.Println("got map change") - response.Messages = append(response.Messages, messages7.Ready{}) + result.Response.Messages = append(result.Response.Messages, messages7.Ready{}) } else if msg == network7.MsgSysConReady { fmt.Println("got ready") - response.Messages = append(response.Messages, connection.MsgStartInfo()) + result.Response.Messages = append(result.Response.Messages, connection.MsgStartInfo()) } else if msg == network7.MsgSysSnapSingle { // tick := u.GetInt() // fmt.Printf("got snap single tick=%d\n", tick) - response.Messages = append(response.Messages, messages7.CtrlKeepAlive{}) + result.Response.Messages = append(result.Response.Messages, messages7.CtrlKeepAlive{}) } else { fmt.Printf("unknown system message id=%d data=%x\n", msg, chunk.Data) } @@ -120,10 +120,10 @@ func (client *Connection) OnMotd(motd string) { fmt.Printf("[motd] %s\n", motd) } -func (client *Connection) OnGameMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, response *Packet) { +func (client *Connection) OnGameMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, result *PacketResult) { if msg == network7.MsgGameReadyToEnter { fmt.Println("got ready to enter") - response.Messages = append(response.Messages, messages7.EnterGame{}) + result.Response.Messages = append(result.Response.Messages, messages7.EnterGame{}) } else if msg == network7.MsgGameSvMotd { motd := u.GetString() if motd != "" { @@ -145,7 +145,7 @@ func (client *Connection) OnGameMsg(msg int, chunk chunk7.Chunk, u *packer.Unpac } } -func (client *Connection) OnMessage(chunk chunk7.Chunk, response *Packet) { +func (client *Connection) OnMessage(chunk chunk7.Chunk, result *PacketResult) { // fmt.Printf("got chunk size=%d data=%v\n", chunk.Header.Size, chunk.Data) if chunk.Header.Flags.Vital { @@ -161,39 +161,48 @@ func (client *Connection) OnMessage(chunk chunk7.Chunk, response *Packet) { msg >>= 1 if sys { - client.OnSystemMsg(msg, chunk, &u, response) + client.OnSystemMsg(msg, chunk, &u, result) } else { - client.OnGameMsg(msg, chunk, &u, response) + client.OnGameMsg(msg, chunk, &u, result) } } -func (connection *Connection) OnPacketPayload(header []byte, data []byte, response *Packet) (*Packet, error) { +func (connection *Connection) OnPacketPayload(data []byte, result *PacketResult) (*PacketResult, error) { chunks := chunk7.UnpackChunks(data) for _, c := range chunks { - connection.OnMessage(c, response) + connection.OnMessage(c, result) } - return response, nil + return result, nil } -// the response packet might be nil even if there is no error -func (connection *Connection) OnPacket(data []byte) (*Packet, error) { - header := PacketHeader{} - headerRaw := data[:7] - payload := data[7:] - header.Unpack(headerRaw) - response := connection.BuildResponse() +type PacketResult struct { + // Suggested response that should be sent to the server + // Will be *nil* if no response should be sent + Response *Packet - if header.Flags.Control { + // Incoming traffic from the server parsed into a Packet struct + Packet *Packet +} + +func (connection *Connection) OnPacket(data []byte) (*PacketResult, error) { + result := &PacketResult{ + Response: connection.BuildResponse(), + Packet: &Packet{}, + } + result.Packet.Header.Unpack(data[:7]) + payload := data[7:] + + if result.Packet.Header.Flags.Control { ctrlMsg := int(payload[0]) fmt.Printf("got ctrl msg %d\n", ctrlMsg) if ctrlMsg == network7.MsgCtrlToken { copy(connection.ServerToken[:], payload[1:5]) - response.Header.Token = connection.ServerToken + result.Response.Header.Token = connection.ServerToken fmt.Printf("got server token %x\n", connection.ServerToken) - response.Messages = append( - response.Messages, + result.Response.Messages = append( + result.Response.Messages, messages7.CtrlConnect{ Token: connection.ClientToken, }, @@ -201,7 +210,7 @@ func (connection *Connection) OnPacket(data []byte) (*Packet, error) { } else if ctrlMsg == network7.MsgCtrlAccept { fmt.Println("got accept") // TODO: don't hardcode info - response.Messages = append(response.Messages, messages7.Info{}) + result.Response.Messages = append(result.Response.Messages, messages7.Info{}) } else if ctrlMsg == network7.MsgCtrlClose { // TODO: get length from packet header to determine if a reason is set or not // len(data) -> is 1400 (maxPacketLen) @@ -214,14 +223,14 @@ func (connection *Connection) OnPacket(data []byte) (*Packet, error) { fmt.Printf("unknown control message: %x\n", data) } - if len(response.Messages) == 0 { + if len(result.Response.Messages) == 0 { return nil, nil } - return response, nil + return result, nil } - if header.Flags.Compression { + if result.Packet.Header.Flags.Compression { huff := huffman.Huffman{} var err error payload, err = huff.Decompress(payload) @@ -231,6 +240,6 @@ func (connection *Connection) OnPacket(data []byte) (*Packet, error) { } } - response, err := connection.OnPacketPayload(headerRaw, payload, response) - return response, err + result, err := connection.OnPacketPayload(payload, result) + return result, err } diff --git a/teeworlds.go b/teeworlds.go index aace09e..8fe42a0 100644 --- a/teeworlds.go +++ b/teeworlds.go @@ -60,34 +60,33 @@ func main() { go readNetwork(ch, conn) - packet := client.CtrlToken() - conn.Write(packet.Pack(client)) + tokenPacket := client.CtrlToken() + conn.Write(tokenPacket.Pack(client)) for { time.Sleep(10_000_000) select { case msg := <-ch: - packet, err = client.OnPacket(msg) + result, err := client.OnPacket(msg) if err != nil { panic(err) } - if packet != nil { + if result.Response != nil { // example of modifying outgoing traffic - for i, msg := range packet.Messages { + for i, msg := range result.Response.Messages { if msg.MsgId() == network7.MsgCtrlConnect { - if connect, ok := packet.Messages[0].(messages7.CtrlConnect); ok { + if connect, ok := result.Response.Messages[0].(messages7.CtrlConnect); ok { connect.Token = [4]byte{0xaa, 0xaa, 0xaa, 0xaa} - packet.Messages[i] = connect + result.Response.Messages[i] = connect } } } - conn.Write(packet.Pack(client)) + conn.Write(result.Response.Pack(client)) } default: // do nothing } } - }