Support packing control packets
This commit is contained in:
parent
390e470f3b
commit
781248ae79
|
@ -1,4 +1,4 @@
|
|||
from twnet_parser.packet import parse7
|
||||
from twnet_parser.packet import parse7, TwPacket
|
||||
from twnet_parser.messages7.control.keep_alive import CtrlKeepAlive
|
||||
from twnet_parser.messages7.control.connect import CtrlConnect
|
||||
from twnet_parser.messages7.control.accept import CtrlAccept
|
||||
|
@ -14,6 +14,68 @@ def test_parse_7_close():
|
|||
assert packet.messages[0].message_name == 'close'
|
||||
assert len(packet.messages) == 1
|
||||
|
||||
def test_pack_7_close_packet_defaults():
|
||||
packet: TwPacket = TwPacket()
|
||||
close = CtrlClose()
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x04\x00\x00\xff\xff\xff\xff\x04'
|
||||
|
||||
def test_pack_7_close_packet_token():
|
||||
packet: TwPacket = TwPacket()
|
||||
packet.header.token = b'\xaa\xbb\xcc\xdd'
|
||||
close = CtrlClose()
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x04\x00\x00\xaa\xbb\xcc\xdd\x04'
|
||||
|
||||
def test_pack_7_close_packet_token_and_reason():
|
||||
packet: TwPacket = TwPacket()
|
||||
packet.header.token = b'\xaa\xbb\xcc\xdd'
|
||||
close = CtrlClose(reason = "timeout")
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x04\x00\x00\xaa\xbb\xcc\xdd\x04timeout\x00'
|
||||
|
||||
def test_pack_7_close_packet_set_control_false():
|
||||
"""
|
||||
This packet is wrong according to tw spec
|
||||
|
||||
because control is sent but flag not set
|
||||
"""
|
||||
packet: TwPacket = TwPacket()
|
||||
packet.header.token = b'\xaa\xbb\xcc\xdd'
|
||||
packet.header.flags.control = False
|
||||
close = CtrlClose()
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x00\x00\x00\xaa\xbb\xcc\xdd\x04'
|
||||
|
||||
def test_pack_7_close_packet_set_control_false_and_num_chunks2():
|
||||
"""
|
||||
This packet is wrong according to tw spec
|
||||
|
||||
because control is sent but flag not set
|
||||
and because num chunks is not set to zero
|
||||
"""
|
||||
packet: TwPacket = TwPacket()
|
||||
packet.header.token = b'\xaa\xbb\xcc\xdd'
|
||||
packet.header.flags.control = False
|
||||
packet.header.num_chunks = 2
|
||||
close = CtrlClose()
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x00\x00\x02\xaa\xbb\xcc\xdd\x04'
|
||||
|
||||
def test_pack_7_close_packet_set_control_true():
|
||||
packet: TwPacket = TwPacket()
|
||||
packet.header.token = b'\xaa\xbb\xcc\xdd'
|
||||
packet.header.flags.control = True
|
||||
close = CtrlClose()
|
||||
packet.messages.append(close)
|
||||
data = packet.pack()
|
||||
assert data == b'\x04\x00\x00\xaa\xbb\xcc\xdd\x04'
|
||||
|
||||
def test_pack_7_close():
|
||||
close = CtrlClose()
|
||||
data = close.pack()
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Protocol
|
|||
class CtrlMessage(Protocol):
|
||||
message_name: str
|
||||
message_id: int
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
...
|
||||
def pack(self) -> bytes:
|
||||
def pack(self, we_are_a_client: bool = True) -> bytes:
|
||||
...
|
||||
|
|
|
@ -5,8 +5,8 @@ class CtrlAccept(PrettyPrint):
|
|||
self.message_name: str = 'accept'
|
||||
self.message_id: int = 2
|
||||
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
return False
|
||||
|
||||
def pack(self, client: bool = True) -> bytes:
|
||||
def pack(self, we_are_a_client: bool = True) -> bytes:
|
||||
return b''
|
||||
|
|
|
@ -17,12 +17,12 @@ class CtrlClose(PrettyPrint):
|
|||
# first byte of data
|
||||
# has to be the first byte of the message payload
|
||||
# NOT the chunk header and NOT the message id
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
unpacker = Unpacker(data)
|
||||
self.reason = unpacker.get_str() # TODO: this is an optional field
|
||||
return True
|
||||
|
||||
def pack(self) -> bytes:
|
||||
def pack(self, we_are_a_client: bool = True) -> bytes:
|
||||
if self.reason:
|
||||
return pack_str(self.reason)
|
||||
return b''
|
||||
|
|
|
@ -10,12 +10,12 @@ class CtrlConnect(PrettyPrint):
|
|||
|
||||
self.response_token: bytes = response_token
|
||||
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
# anti reflection attack
|
||||
if len(data) < 512:
|
||||
return False
|
||||
self.response_token = data[0:4]
|
||||
return True
|
||||
|
||||
def pack(self) -> bytes:
|
||||
def pack(self, we_are_a_client: bool = True) -> bytes:
|
||||
return self.response_token + bytes(508)
|
||||
|
|
|
@ -5,8 +5,8 @@ class CtrlKeepAlive(PrettyPrint):
|
|||
self.message_name: str = 'keep_alive'
|
||||
self.message_id: int = 0
|
||||
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
return False
|
||||
|
||||
def pack(self) -> bytes:
|
||||
def pack(self, we_are_a_client: bool = True) -> bytes:
|
||||
return b''
|
||||
|
|
|
@ -10,7 +10,7 @@ class CtrlToken(PrettyPrint):
|
|||
|
||||
self.response_token: bytes = response_token
|
||||
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = False) -> bool:
|
||||
def unpack(self, data: bytes, we_are_a_client: bool = True) -> bool:
|
||||
if not we_are_a_client:
|
||||
# anti reflection attack
|
||||
if len(data) < 512:
|
||||
|
|
|
@ -26,19 +26,19 @@ CHUNKFLAG7_RESEND = 2
|
|||
PACKET_HEADER7_SIZE = 7
|
||||
|
||||
class PacketFlags7(PrettyPrint):
|
||||
def __init__(self):
|
||||
self.control = False
|
||||
self.resend = False
|
||||
self.compression = False
|
||||
self.connless = False
|
||||
def __init__(self) -> None:
|
||||
self.control: Optional[bool] = None
|
||||
self.resend: Optional[bool] = None
|
||||
self.compression: Optional[bool] = None
|
||||
self.connless: Optional[bool] = None
|
||||
|
||||
class PacketFlags6(PrettyPrint):
|
||||
def __init__(self):
|
||||
self.token = False
|
||||
self.control = False
|
||||
self.resend = False
|
||||
self.compression = False
|
||||
self.connless = False
|
||||
def __init__(self) -> None:
|
||||
self.token: Optional[bool] = None
|
||||
self.control: Optional[bool] = None
|
||||
self.resend: Optional[bool] = None
|
||||
self.compression: Optional[bool] = None
|
||||
self.connless: Optional[bool] = None
|
||||
|
||||
class PacketHeader(PrettyPrint):
|
||||
def __init__(
|
||||
|
@ -103,9 +103,10 @@ class TwPacket(PrettyPrint):
|
|||
self.header: PacketHeader = PacketHeader()
|
||||
self.messages: list[Union[CtrlMessage, NetMessage]] = []
|
||||
|
||||
def pack(self) -> bytes:
|
||||
messages: bytes = b''
|
||||
def pack(self, we_are_a_client = True) -> bytes:
|
||||
payload: bytes = b''
|
||||
msg: Union[CtrlMessage, NetMessage]
|
||||
is_control: bool = False
|
||||
for msg in self.messages:
|
||||
# TODO: this is super ugly
|
||||
# revist https://gitlab.com/teeworlds-network/twnet_parser/-/issues/1
|
||||
|
@ -113,17 +114,27 @@ class TwPacket(PrettyPrint):
|
|||
# maybe because CtrlMessage and NetMessage are no actual classes
|
||||
# but just ducktyping helpers
|
||||
if not hasattr(msg, 'system_message'):
|
||||
raise ValueError('Packing control messages is not supported yet')
|
||||
msg = cast(NetMessage, msg)
|
||||
msg_payload: bytes = pack_int((msg.message_id<<1)|(int)(msg.system_message))
|
||||
msg_payload += msg.pack()
|
||||
if msg.header.size is None:
|
||||
msg.header.size = len(msg_payload)
|
||||
messages += msg.header.pack()
|
||||
messages += msg_payload
|
||||
is_control = True
|
||||
msg = cast(CtrlMessage, msg)
|
||||
payload += pack_int(msg.message_id)
|
||||
payload += msg.pack(we_are_a_client)
|
||||
else:
|
||||
msg = cast(NetMessage, msg)
|
||||
msg_payload: bytes = pack_int((msg.message_id<<1)|(int)(msg.system_message))
|
||||
msg_payload += msg.pack()
|
||||
if msg.header.size is None:
|
||||
msg.header.size = len(msg_payload)
|
||||
payload += msg.header.pack()
|
||||
payload += msg_payload
|
||||
if self.header.num_chunks is None:
|
||||
self.header.num_chunks = len(self.messages)
|
||||
return self.header.pack() + messages
|
||||
if is_control:
|
||||
self.header.num_chunks = 0
|
||||
else:
|
||||
self.header.num_chunks = len(self.messages)
|
||||
if is_control:
|
||||
if self.header.flags.control is None:
|
||||
self.header.flags.control = True
|
||||
return self.header.pack() + payload
|
||||
|
||||
class PacketHeaderParser7():
|
||||
def parse_flags7(self, data: bytes) -> PacketFlags7:
|
||||
|
@ -243,5 +254,5 @@ class PacketParser():
|
|||
def parse6(data: bytes) -> TwPacket:
|
||||
raise NotImplementedError()
|
||||
|
||||
def parse7(data: bytes, we_are_a_client: bool = False) -> TwPacket:
|
||||
def parse7(data: bytes, we_are_a_client: bool = True) -> TwPacket:
|
||||
return PacketParser().parse7(data, we_are_a_client)
|
||||
|
|
Loading…
Reference in a new issue