From 72bf459d79e6bcfcb215d512a9ef28428964b36b Mon Sep 17 00:00:00 2001 From: ChillerDragon Date: Sat, 22 Jun 2024 12:25:29 +0800 Subject: [PATCH] Refactor: separate packet parsing and connection state Computing a response to a packet is now separate from unpacking a packet --- messages7/ctrl_connect.go | 3 +- messages7/ctrl_token.go | 3 +- messages7/snap_single.go | 61 ++++++++++++ messages7/unknown.go | 8 +- protocol7/connection.go | 204 ++++++++++++-------------------------- protocol7/packet.go | 11 +- protocol7/packet_test.go | 8 +- teeworlds.go | 32 +++--- 8 files changed, 158 insertions(+), 172 deletions(-) create mode 100644 messages7/snap_single.go diff --git a/messages7/ctrl_connect.go b/messages7/ctrl_connect.go index 5b6bb55..502b605 100644 --- a/messages7/ctrl_connect.go +++ b/messages7/ctrl_connect.go @@ -36,9 +36,8 @@ func (msg CtrlConnect) Pack() []byte { ) } -// TODO: no idea if this works func (msg *CtrlConnect) Unpack(u *packer.Unpacker) { - msg.Token = [4]byte(u.Data()) + msg.Token = [4]byte(u.Rest()) } func (msg *CtrlConnect) Header() *chunk7.ChunkHeader { diff --git a/messages7/ctrl_token.go b/messages7/ctrl_token.go index b93e99e..c21c9f8 100644 --- a/messages7/ctrl_token.go +++ b/messages7/ctrl_token.go @@ -39,9 +39,8 @@ func (msg CtrlToken) Pack() []byte { ) } -// TODO: no idea if this works func (msg *CtrlToken) Unpack(u *packer.Unpacker) { - msg.Token = [4]byte(u.Data()) + msg.Token = [4]byte(u.Rest()) } func (msg *CtrlToken) Header() *chunk7.ChunkHeader { diff --git a/messages7/snap_single.go b/messages7/snap_single.go new file mode 100644 index 0000000..5af71b4 --- /dev/null +++ b/messages7/snap_single.go @@ -0,0 +1,61 @@ +package messages7 + +import ( + "slices" + + "github.com/teeworlds-go/teeworlds/chunk7" + "github.com/teeworlds-go/teeworlds/network7" + "github.com/teeworlds-go/teeworlds/packer" +) + +type SnapSingle struct { + ChunkHeader *chunk7.ChunkHeader + + GameTick int + DeltaTick int + Crc int + PartSize int + Data []byte +} + +func (msg SnapSingle) MsgId() int { + return network7.MsgSysSnapSingle +} + +func (msg SnapSingle) MsgType() network7.MsgType { + return network7.TypeNet +} + +func (msg SnapSingle) System() bool { + return true +} + +func (msg SnapSingle) Vital() bool { + return false +} + +func (msg SnapSingle) Pack() []byte { + return slices.Concat( + packer.PackInt(msg.GameTick), + packer.PackInt(msg.DeltaTick), + packer.PackInt(msg.Crc), + packer.PackInt(msg.PartSize), + msg.Data[:], + ) +} + +func (msg *SnapSingle) Unpack(u *packer.Unpacker) { + msg.GameTick = u.GetInt() + msg.DeltaTick = u.GetInt() + msg.Crc = u.GetInt() + msg.PartSize = u.GetInt() + msg.Data = u.Rest() +} + +func (msg *SnapSingle) Header() *chunk7.ChunkHeader { + return msg.ChunkHeader +} + +func (msg *SnapSingle) SetHeader(header *chunk7.ChunkHeader) { + msg.ChunkHeader = header +} diff --git a/messages7/unknown.go b/messages7/unknown.go index acbc49e..50d7b64 100644 --- a/messages7/unknown.go +++ b/messages7/unknown.go @@ -39,8 +39,7 @@ func (msg Unknown) System() bool { } func (msg Unknown) Vital() bool { - // TODO: check is not ctrl and then unpack Data - panic("not implemented yet") + panic("You are not mean't to pack unknown messages. Use msg.Header().Vital instead.") } func (msg Unknown) Pack() []byte { @@ -52,7 +51,10 @@ func (msg *Unknown) Unpack(u *packer.Unpacker) { } func (msg *Unknown) Header() *chunk7.ChunkHeader { - return nil + if msg.Type == network7.TypeControl { + return nil + } + return msg.ChunkHeader } func (msg *Unknown) SetHeader(header *chunk7.ChunkHeader) { diff --git a/protocol7/connection.go b/protocol7/connection.go index ede6658..61dcb5f 100644 --- a/protocol7/connection.go +++ b/protocol7/connection.go @@ -4,13 +4,9 @@ import ( "bytes" "fmt" "os" - "slices" - "github.com/teeworlds-go/huffman" - "github.com/teeworlds-go/teeworlds/chunk7" "github.com/teeworlds-go/teeworlds/messages7" "github.com/teeworlds-go/teeworlds/network7" - "github.com/teeworlds-go/teeworlds/packer" ) type Player struct { @@ -96,19 +92,20 @@ func byteSliceToString(s []byte) string { return string(s) } -func (connection *Connection) OnSystemMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, result *PacketResult) bool { - if msg == network7.MsgSysMapChange { +func (connection *Connection) OnSystemMsg(msg messages7.NetMessage, response *Packet) bool { + // TODO: is this shadow nasty? + switch msg := msg.(type) { + case *messages7.MapChange: fmt.Println("got map change") - result.Response.Messages = append(result.Response.Messages, &messages7.Ready{}) - } else if msg == network7.MsgSysConReady { + response.Messages = append(response.Messages, &messages7.Ready{}) + case *messages7.ConReady: fmt.Println("got ready") - result.Response.Messages = append(result.Response.Messages, connection.MsgStartInfo()) - } else if msg == network7.MsgSysSnapSingle { - // tick := u.GetInt() - // fmt.Printf("got snap single tick=%d\n", tick) - result.Response.Messages = append(result.Response.Messages, &messages7.CtrlKeepAlive{}) - } else { - fmt.Printf("unknown system message id=%d data=%x\n", msg, chunk.Data) + response.Messages = append(response.Messages, connection.MsgStartInfo()) + case *messages7.SnapSingle: + // fmt.Printf("got snap single tick=%d\n", msg.GameTick) + response.Messages = append(response.Messages, &messages7.CtrlKeepAlive{}) + default: + fmt.Printf("unknown system message id=%d payload=%x\n", msg.MsgId(), msg.Pack()) return false } return true @@ -123,155 +120,78 @@ func (client *Connection) OnChatMessage(msg *messages7.SvChat) { fmt.Printf("[chat] <%s> %s\n", name, msg.Message) } -func (client *Connection) OnMotd(motd string) { - fmt.Printf("[motd] %s\n", motd) -} - -func (client *Connection) OnGameMsg(msg int, chunk chunk7.Chunk, u *packer.Unpacker, result *PacketResult) bool { - if msg == network7.MsgGameReadyToEnter { +func (connection *Connection) OnGameMsg(msg messages7.NetMessage, response *Packet) bool { + // TODO: is this shadow nasty? + switch msg := msg.(type) { + case *messages7.ReadyToEnter: fmt.Println("got ready to enter") - result.Packet.Messages = append(result.Packet.Messages, &messages7.Ready{ChunkHeader: &chunk.Header}) - result.Response.Messages = append(result.Response.Messages, &messages7.EnterGame{}) - } else if msg == network7.MsgGameSvMotd { - motd := u.GetString() - if motd != "" { - client.OnMotd(motd) - } - } else if msg == network7.MsgGameSvChat { - chat := &messages7.SvChat{ChunkHeader: &chunk.Header} - chat.Unpack(u) - client.OnChatMessage(chat) - result.Packet.Messages = append(result.Packet.Messages, chat) - } else if msg == network7.MsgGameSvClientInfo { - clientId := packer.UnpackInt(chunk.Data[1:]) - client.Players[clientId].Info.Unpack(u) - - fmt.Printf("got client info id=%d name=%s\n", clientId, client.Players[clientId].Info.Name) - } else { - fmt.Printf("unknown game message id=%d data=%x\n", msg, chunk.Data) + response.Messages = append(response.Messages, &messages7.EnterGame{}) + case *messages7.SvMotd: + fmt.Printf("[motd] %s\n", msg.Message) + case *messages7.SvChat: + connection.OnChatMessage(msg) + case *messages7.SvClientInfo: + connection.Players[msg.ClientId].Info = *msg + fmt.Printf("got client info id=%d name=%s\n", msg.ClientId, msg.Name) + default: + fmt.Printf("unknown game message id=%d payload=%x\n", msg.MsgId(), msg.Pack()) return false } return true } -func (client *Connection) OnMessage(chunk chunk7.Chunk, result *PacketResult) bool { - // fmt.Printf("got chunk size=%d data=%v\n", chunk.Header.Size, chunk.Data) - - if chunk.Header.Flags.Vital { - client.Ack++ +func (connection *Connection) OnMessage(msg messages7.NetMessage, response *Packet) bool { + if msg.Header() == nil { + // this is probably an unknown message + fmt.Printf("warning ignoring msgId=%d because header is nil\n", msg.MsgId()) + return false + } + if msg.Header().Flags.Vital { + connection.Ack++ } - u := packer.Unpacker{} - u.Reset(chunk.Data) - - msg := u.GetInt() - - sys := msg&1 != 0 - msg >>= 1 - - if sys { - return client.OnSystemMsg(msg, chunk, &u, result) + if msg.System() { + return connection.OnSystemMsg(msg, response) } - return client.OnGameMsg(msg, chunk, &u, result) + return connection.OnGameMsg(msg, response) } -func (connection *Connection) OnPacketPayload(data []byte, result *PacketResult) (*PacketResult, error) { - chunks := chunk7.UnpackChunks(data) +// Takes a full teeworlds packet as argument +// And returns the response packet from the clients perspective +func (connection *Connection) OnPacket(packet *Packet) *Packet { + response := connection.BuildResponse() - for _, c := range chunks { - if connection.OnMessage(c, result) == false { - unknown := &messages7.Unknown{ - Data: slices.Concat(c.Header.Pack(), c.Data), - Type: network7.TypeNet, - } - result.Packet.Messages = append(result.Packet.Messages, unknown) - } - } - return result, nil - -} - -type PacketResult struct { - // Suggested response that should be sent to the server - // Will be *nil* if no response should be sent - Response *Packet - - // Incoming traffic from the server parsed into a Packet struct - Packet *Packet -} - -// TODO: there should be a Packet.Unpack() -// -// and it should only do the parsing no state handling or responses -// and Connection.OnPack() should take a Packet instance as parameter -// not raw data -// So ideally it would look like this: -// -// packet := Packet{} -// packet.Unpack(data) -// conn := Connection{} -// result, err := conn.OnPacket(packet) -func (connection *Connection) OnPacket(data []byte) (*PacketResult, error) { - result := &PacketResult{ - Response: connection.BuildResponse(), - Packet: &Packet{}, - } - result.Packet.Header.Unpack(data[:7]) - payload := data[7:] - - if result.Packet.Header.Flags.Control { - ctrlMsg := int(payload[0]) - fmt.Printf("got ctrl msg %d\n", ctrlMsg) - if ctrlMsg == network7.MsgCtrlToken { - copy(connection.ServerToken[:], payload[1:5]) - result.Response.Header.Token = connection.ServerToken - fmt.Printf("got server token %x\n", connection.ServerToken) - result.Packet.Messages = append(result.Packet.Messages, &messages7.CtrlToken{Token: connection.ServerToken}) - result.Response.Messages = append( - result.Response.Messages, + if packet.Header.Flags.Control { + msg := packet.Messages[0] + fmt.Printf("got ctrl msg %d\n", msg.MsgId()) + // TODO: is this shadow nasty? + switch msg := msg.(type) { + case *messages7.CtrlToken: + fmt.Printf("got server token %x\n", msg.Token) + connection.ServerToken = msg.Token + response.Header.Token = msg.Token + response.Messages = append( + response.Messages, &messages7.CtrlConnect{ Token: connection.ClientToken, }, ) - } else if ctrlMsg == network7.MsgCtrlAccept { + case *messages7.CtrlAccept: fmt.Println("got accept") - result.Packet.Messages = append(result.Packet.Messages, &messages7.CtrlAccept{}) // TODO: don't hardcode info - result.Response.Messages = append(result.Response.Messages, &messages7.Info{}) - } else if ctrlMsg == network7.MsgCtrlClose { - // TODO: get length from packet header to determine if a reason is set or not - // len(data) -> is 1400 (maxPacketLen) - - reason := byteSliceToString(payload) - fmt.Printf("disconnected (%s)\n", reason) - + response.Messages = append(response.Messages, &messages7.Info{}) + case *messages7.CtrlClose: + fmt.Printf("disconnected (%s)\n", msg.Reason) os.Exit(0) - } else { - unknown := &messages7.Unknown{ - Data: payload, - Type: network7.TypeControl, - } - result.Packet.Messages = append(result.Packet.Messages, unknown) - fmt.Printf("unknown control message: %x\n", data) + default: + fmt.Printf("unknown control message: %d\n", msg.MsgId()) } - - if len(result.Response.Messages) == 0 { - return nil, nil - } - - return result, nil + return response } - if result.Packet.Header.Flags.Compression { - huff := huffman.Huffman{} - var err error - payload, err = huff.Decompress(payload) - if err != nil { - fmt.Printf("huffman error: %v\n", err) - return nil, nil - } + for _, msg := range packet.Messages { + connection.OnMessage(msg, response) } - result, err := connection.OnPacketPayload(payload, result) - return result, err + return response } diff --git a/protocol7/packet.go b/protocol7/packet.go index 0d7e0d7..e15b002 100644 --- a/protocol7/packet.go +++ b/protocol7/packet.go @@ -89,6 +89,10 @@ func (packet *Packet) unpackSystem(msgId int, chunk chunk7.Chunk, u *packer.Unpa msg := &messages7.ConReady{ChunkHeader: &chunk.Header} msg.Unpack(u) packet.Messages = append(packet.Messages, msg) + } else if msgId == network7.MsgSysSnapSingle { + msg := &messages7.SnapSingle{ChunkHeader: &chunk.Header} + msg.Unpack(u) + packet.Messages = append(packet.Messages, msg) } else { return false } @@ -98,7 +102,7 @@ func (packet *Packet) unpackSystem(msgId int, chunk chunk7.Chunk, u *packer.Unpa func (packet *Packet) unpackGame(msgId int, chunk chunk7.Chunk, u *packer.Unpacker) bool { if msgId == network7.MsgGameReadyToEnter { - msg := &messages7.Ready{ChunkHeader: &chunk.Header} + msg := &messages7.ReadyToEnter{ChunkHeader: &chunk.Header} msg.Unpack(u) packet.Messages = append(packet.Messages, msg) } else if msgId == network7.MsgGameSvMotd { @@ -140,8 +144,9 @@ func (packet *Packet) unpackPayload(payload []byte) { for _, c := range chunks { if packet.unpackChunk(c) == false { unknown := &messages7.Unknown{ - Data: slices.Concat(c.Header.Pack(), c.Data), - Type: network7.TypeNet, + ChunkHeader: &c.Header, + Data: slices.Concat(c.Header.Pack(), c.Data), + Type: network7.TypeNet, } packet.Messages = append(packet.Messages, unknown) } diff --git a/protocol7/packet_test.go b/protocol7/packet_test.go index 405bad4..5141570 100644 --- a/protocol7/packet_test.go +++ b/protocol7/packet_test.go @@ -135,12 +135,10 @@ func TestRepackUnknownMessages(t *testing.T) { } conn := Connection{} - result, err := conn.OnPacket(dump) - if err != nil { - t.Errorf("Unexpected error %v\n", err) - } - repack := result.Packet.Pack(&conn) + packet := Packet{} + packet.Unpack(dump) + repack := packet.Pack(&conn) if !reflect.DeepEqual(repack, dump) { t.Errorf("got %v, wanted %v", repack, dump) diff --git a/teeworlds.go b/teeworlds.go index bb8adc5..ab704fc 100644 --- a/teeworlds.go +++ b/teeworlds.go @@ -86,39 +86,41 @@ func main() { time.Sleep(10_000_000) select { case msg := <-ch: - result, err := client.OnPacket(msg) + packet := &protocol7.Packet{} + err := packet.Unpack(msg) if err != nil { panic(err) } - if result != nil && result.Response != nil { + response := client.OnPacket(packet) + if response != nil { // example of inspecting incoming trafic - for i, msg := range result.Packet.Messages { - if msg.MsgId() == network7.MsgGameSvChat { - var chat *messages7.SvChat - var ok bool - if chat, ok = result.Packet.Messages[i].(*messages7.SvChat); ok { - fmt.Printf("got chat msg: %s\n", chat.Message) + for i, _ := range packet.Messages { + var chat *messages7.SvChat + var ok bool + if chat, ok = packet.Messages[i].(*messages7.SvChat); ok { + fmt.Printf("got chat msg: %s\n", chat.Message) - // modify chat if this was a proxy - result.Packet.Messages[i] = chat - } + // modify chat if this was a proxy + packet.Messages[i] = chat } } // example of modifying outgoing traffic - for i, msg := range result.Response.Messages { + for i, msg := range response.Messages { if msg.MsgId() == network7.MsgCtrlConnect { var connect *messages7.CtrlConnect var ok bool - if connect, ok = result.Response.Messages[i].(*messages7.CtrlConnect); ok { + if connect, ok = response.Messages[i].(*messages7.CtrlConnect); ok { connect.Token = [4]byte{0xaa, 0xaa, 0xaa, 0xaa} - result.Response.Messages[i] = connect + response.Messages[i] = connect } } } - conn.Write(result.Response.Pack(client)) + if len(response.Messages) > 0 || response.Header.Flags.Resend { + conn.Write(response.Pack(client)) + } } default: // do nothing