feat: implement dict on TwPacket

This commit is contained in:
ChillerDragon 2024-06-16 11:36:07 +08:00
parent 5581929e76
commit 355699d33e
8 changed files with 155 additions and 3 deletions

73
tests/dict_cast_test.py Normal file
View file

@ -0,0 +1,73 @@
from typing import cast
from twnet_parser.messages6.game.cl_change_info import MsgClChangeInfo
from twnet_parser.packet import TwPacket, NetMessage
def test_change_info() -> None:
change_info = dict(MsgClChangeInfo())
assert change_info['message_type'] == 'game'
assert change_info['message_name'] == 'cl_change_info'
assert change_info['system_message'] == False
assert change_info['message_id'] == 21
assert change_info['header'] == {'flags': [], 'seq': -1, 'size': None, 'version': '0.6'}
assert change_info['name'] == 'default'
assert change_info['clan'] == 'default'
assert change_info['country'] == 0
assert change_info['skin'] == 'default'
assert change_info['use_custom_color'] == False
assert change_info['color_body'] == 0
assert change_info['color_feet'] == 0
def test_empty_packet() -> None:
packet = dict(TwPacket())
assert packet == {
'header': {
'ack': 0,
'flags': [],
'num_chunks': None,
'token': b'\xff\xff\xff\xff'
},
'payload_decompressed': b'',
'payload_raw': b'',
'version': '0.7',
'messages': []
}
def test_change_info_packet() -> None:
packet = TwPacket()
packet.messages.append(cast(NetMessage, MsgClChangeInfo()))
packet = dict(packet)
assert packet == {
'header': {
'ack': 0,
'flags': [],
'num_chunks': None,
'token': b'\xff\xff\xff\xff'
},
'payload_decompressed': b'',
'payload_raw': b'',
'version': '0.7',
'messages': [
{
'clan': 'default',
'color_body': 0,
'color_feet': 0,
'country': 0,
'header': {
'flags': [],
'seq': -1,
'size': None,
'version': '0.6'
},
'message_id': 21,
'message_name': 'cl_change_info',
'message_type': 'game',
'name': 'default',
'skin': 'default',
'system_message': False,
'use_custom_color': False
}
]
}

View file

@ -1,4 +1,4 @@
from typing import Protocol, Literal, Annotated
from typing import Protocol, Literal, Annotated, Iterator, Any
class ConnlessMessage(Protocol):
message_type: Literal['connless']
@ -8,3 +8,5 @@ class ConnlessMessage(Protocol):
...
def pack(self) -> bytes:
...
def __iter__(self) -> Iterator[tuple[str, Any]]:
...

View file

@ -1,4 +1,4 @@
from typing import Protocol, Literal
from typing import Protocol, Literal, Iterator, Any
class CtrlMessage(Protocol):
message_type: Literal['control']
@ -8,3 +8,5 @@ class CtrlMessage(Protocol):
...
def pack(self, we_are_a_client: bool = True) -> bytes:
...
def __iter__(self) -> Iterator[tuple[str, Any]]:
...

View file

@ -19,6 +19,15 @@ class MsgDDNetUuid(PrettyPrint):
self.payload: bytes = payload
def __iter__(self):
yield 'message_type', self.message_type
yield 'message_name', self.message_name
yield 'system_message', self.system_message
yield 'message_id', self.message_id
yield 'header', dict(self.header)
yield 'payload', self.payload
# first byte of data
# has to be the first byte of the message payload
# NOT the chunk header and NOT the message id

View file

@ -35,6 +35,24 @@ class MsgSnap(PrettyPrint):
self.data: bytes = data
self.snapshot = Snapshot()
def __iter__(self):
yield 'message_type', self.message_type
yield 'message_name', self.message_name
yield 'system_message', self.system_message
yield 'message_id', self.message_id
yield 'header', dict(self.header)
yield 'tick', self.tick
yield 'delta_tick', self.delta_tick
yield 'num_parts', self.num_parts
yield 'part', self.part
yield 'crc', self.crc
yield 'data_size', self.data_size
yield 'data', self.data
# TODO: dict snapshot
# yield 'snapshot', 'TODO'
# first byte of data
# has to be the first byte of the message payload
# NOT the chunk header and NOT the message id

View file

@ -35,6 +35,24 @@ class MsgSnap(PrettyPrint):
self.data: bytes = data
self.snapshot = Snapshot()
def __iter__(self):
yield 'message_type', self.message_type
yield 'message_name', self.message_name
yield 'system_message', self.system_message
yield 'message_id', self.message_id
yield 'header', dict(self.header)
yield 'tick', self.tick
yield 'delta_tick', self.delta_tick
yield 'num_parts', self.num_parts
yield 'part', self.part
yield 'crc', self.crc
yield 'data_size', self.data_size
yield 'data', self.data
# TODO: dict snapshot
# yield 'snapshot', 'TODO'
# first byte of data
# has to be the first byte of the message payload
# NOT the chunk header and NOT the message id

View file

@ -1,4 +1,4 @@
from typing import Protocol, Literal
from typing import Protocol, Literal, Iterator, Any
from twnet_parser.chunk_header import ChunkHeader
@ -12,3 +12,5 @@ class NetMessage(Protocol):
...
def pack(self) -> bytes:
...
def __iter__(self) -> Iterator[tuple[str, Any]]:
...

View file

@ -109,6 +109,16 @@ class PacketHeader6(PrettyPrint):
self.connless_version: int = NET_PACKETVERSION
self.response_token: bytes = b'\xff\xff\xff\xff'
def __iter__(self):
yield 'flags', list(self.flags)
yield 'ack', self.ack
yield 'token', self.token
yield 'num_chunks', self.num_chunks
if self.flags.connless:
yield 'connless_version', self.connless_version
yield 'response_token', self.response_token
def pack(self) -> bytes:
"""
Generate 7 byte teeworlds 0.6.5 packet header
@ -182,6 +192,16 @@ class PacketHeader7(PrettyPrint):
self.connless_version: int = NET_PACKETVERSION
self.response_token: bytes = b'\xff\xff\xff\xff'
def __iter__(self):
yield 'flags', list(self.flags)
yield 'ack', self.ack
yield 'token', self.token
yield 'num_chunks', self.num_chunks
if self.flags.connless:
yield 'connless_version', self.connless_version
yield 'response_token', self.response_token
def pack(self) -> bytes:
"""
Generate 7 byte teeworlds 0.7 packet header
@ -237,6 +257,14 @@ class TwPacket(PrettyPrint):
raise ValueError(f"Error: invalid packet version '{self.version}'")
self.messages: list[Union[CtrlMessage, NetMessage, ConnlessMessage]] = []
def __iter__(self):
yield 'version', self.version
yield 'payload_raw', self.payload_raw
yield 'payload_decompressed', self.payload_decompressed
yield 'header', dict(self.header)
yield 'messages', [dict(msg) for msg in self.messages]
@property
def version(self) -> Literal['0.6', '0.7']:
return self._version