From 16f2899e3f586cef4e48423b04236538e372c36c Mon Sep 17 00:00:00 2001 From: heinrich5991 Date: Wed, 4 Jul 2012 21:12:15 +0200 Subject: [PATCH] Finished CNetTokenCache Added basic support for mastersrv Conflicts: bam.lua src/engine/shared/network_client.cpp --- bam.lua | 6 +- src/engine/server/register.cpp | 8 +-- src/engine/server/register.h | 6 +- src/engine/server/server.cpp | 2 +- src/engine/shared/network.cpp | 5 +- src/engine/shared/network.h | 70 +++++++++++----------- src/engine/shared/network_client.cpp | 86 ++++++++++++++++++---------- src/engine/shared/network_server.cpp | 80 ++++++++++++++++++++++++-- src/engine/shared/network_token.cpp | 74 ++++++++++++++++++------ src/mastersrv/mastersrv.cpp | 54 ++++++++--------- 10 files changed, 266 insertions(+), 125 deletions(-) diff --git a/bam.lua b/bam.lua index c366c6e0c..f4a970383 100644 --- a/bam.lua +++ b/bam.lua @@ -344,17 +344,17 @@ function BuildTools(settings) local tools = {} for i,v in ipairs(Collect("src/tools/*.cpp", "src/tools/*.c")) do local toolname = PathFilename(PathBase(v)) - tools[i] = Link(settings, toolname, Compile(settings, v), libs["zlib"], libs["wavpack"], libs["png"]) + tools[i] = Link(settings, toolname, Compile(settings, v), libs["zlib"], libs["md5"], libs["wavpack"], libs["png"]) end PseudoTarget(settings.link.Output(settings, "pseudo_tools") .. settings.link.extension, tools) end function BuildMasterserver(settings) - return Link(settings, "mastersrv", Compile(settings, Collect("src/mastersrv/*.cpp")), libs["zlib"]) + return Link(settings, "mastersrv", Compile(settings, Collect("src/mastersrv/*.cpp")), libs["zlib"], libs["md5"]) end function BuildVersionserver(settings) - return Link(settings, "versionsrv", Compile(settings, Collect("src/versionsrv/*.cpp")), libs["zlib"]) + return Link(settings, "versionsrv", Compile(settings, Collect("src/versionsrv/*.cpp")), libs["zlib"], libs["md5"]) end function BuildContent(settings) diff --git a/src/engine/server/register.cpp b/src/engine/server/register.cpp index 9c8c17e3e..8dc4f9a75 100644 --- a/src/engine/server/register.cpp +++ b/src/engine/server/register.cpp @@ -31,7 +31,7 @@ void CRegister::RegisterNewState(int State) m_RegisterStateStart = time_get(); } -void CRegister::RegisterSendFwcheckresponse(NETADDR *pAddr) +void CRegister::RegisterSendFwcheckresponse(NETADDR *pAddr, TOKEN Token) { CNetChunk Packet; Packet.m_ClientID = -1; @@ -39,7 +39,7 @@ void CRegister::RegisterSendFwcheckresponse(NETADDR *pAddr) Packet.m_Flags = NETSENDFLAG_CONNLESS; Packet.m_DataSize = sizeof(SERVERBROWSE_FWRESPONSE); Packet.m_pData = SERVERBROWSE_FWRESPONSE; - m_pNetServer->Send(&Packet); + m_pNetServer->Send(&Packet, Token); } void CRegister::RegisterSendHeartbeat(NETADDR Addr) @@ -235,7 +235,7 @@ void CRegister::RegisterUpdate(int Nettype) } } -int CRegister::RegisterProcessPacket(CNetChunk *pPacket) +int CRegister::RegisterProcessPacket(CNetChunk *pPacket, TOKEN Token) { // check for masterserver address bool Valid = false; @@ -257,7 +257,7 @@ int CRegister::RegisterProcessPacket(CNetChunk *pPacket) if(pPacket->m_DataSize == sizeof(SERVERBROWSE_FWCHECK) && mem_comp(pPacket->m_pData, SERVERBROWSE_FWCHECK, sizeof(SERVERBROWSE_FWCHECK)) == 0) { - RegisterSendFwcheckresponse(&pPacket->m_Address); + RegisterSendFwcheckresponse(&pPacket->m_Address, Token); return 1; } else if(pPacket->m_DataSize == sizeof(SERVERBROWSE_FWOK) && diff --git a/src/engine/server/register.h b/src/engine/server/register.h index c0392380d..8699fab28 100644 --- a/src/engine/server/register.h +++ b/src/engine/server/register.h @@ -3,6 +3,8 @@ #ifndef ENGINE_SERVER_REGISTER_H #define ENGINE_SERVER_REGISTER_H +#include + class CRegister { enum @@ -36,7 +38,7 @@ class CRegister int m_RegisterRegisteredServer; void RegisterNewState(int State); - void RegisterSendFwcheckresponse(NETADDR *pAddr); + void RegisterSendFwcheckresponse(NETADDR *pAddr, TOKEN Token); void RegisterSendHeartbeat(NETADDR Addr); void RegisterSendCountRequest(NETADDR Addr); void RegisterGotCount(struct CNetChunk *pChunk); @@ -45,7 +47,7 @@ public: CRegister(); void Init(class CNetServer *pNetServer, class IEngineMasterServer *pMasterServer, class IConsole *pConsole); void RegisterUpdate(int Nettype); - int RegisterProcessPacket(struct CNetChunk *pPacket); + int RegisterProcessPacket(struct CNetChunk *pPacket, TOKEN Token); }; #endif diff --git a/src/engine/server/server.cpp b/src/engine/server/server.cpp index b7b12418e..60f4623d0 100644 --- a/src/engine/server/server.cpp +++ b/src/engine/server/server.cpp @@ -1156,7 +1156,7 @@ void CServer::PumpNetwork() { // stateless? if(!(Packet.m_Flags&NETSENDFLAG_STATELESS)) - if(m_Register.RegisterProcessPacket(&Packet)) + if(m_Register.RegisterProcessPacket(&Packet, ResponseToken)) continue; if(Packet.m_DataSize >= sizeof(SERVERBROWSE_GETINFO) && mem_comp(Packet.m_pData, SERVERBROWSE_GETINFO, sizeof(SERVERBROWSE_GETINFO)) == 0) diff --git a/src/engine/shared/network.cpp b/src/engine/shared/network.cpp index 0eb8019f5..7335885d2 100644 --- a/src/engine/shared/network.cpp +++ b/src/engine/shared/network.cpp @@ -153,8 +153,9 @@ void CNetBase::SendPacket(NETSOCKET Socket, const NETADDR *pAddr, CNetPacketCons else HeaderSize = NET_PACKETHEADERSIZE; - // compress - CompressedSize = ms_Huffman.Compress(pPacket->m_aChunkData, pPacket->m_DataSize, &aBuffer[HeaderSize], NET_MAX_PACKETSIZE - HeaderSize - 1); + // compress if not ctrl msg + if(!(pPacket->m_Flags&NET_PACKETFLAG_CONTROL)) + CompressedSize = ms_Huffman.Compress(pPacket->m_aChunkData, pPacket->m_DataSize, &aBuffer[HeaderSize], NET_MAX_PACKETSIZE - HeaderSize - 1); // check if the compression was enabled, successful and good enough if(CompressedSize > 0 && CompressedSize < pPacket->m_DataSize) diff --git a/src/engine/shared/network.h b/src/engine/shared/network.h index 3cb4af40e..b5b58f36a 100644 --- a/src/engine/shared/network.h +++ b/src/engine/shared/network.h @@ -157,13 +157,43 @@ public: }; +class CNetTokenManager +{ +public: + void Init(NETSOCKET Socket, int SeedTime = NET_SEEDTIME); + void Update(); + + void GenerateSeed(); + + int ProcessMessage(const NETADDR *pAddr, const CNetPacketConstruct *pPacket, bool Notify); + + bool CheckToken(const NETADDR *pAddr, TOKEN Token, TOKEN ResponseToken, bool Notify); + TOKEN GenerateToken(const NETADDR *pAddr) const; + static TOKEN GenerateToken(const NETADDR *pAddr, int64 Seed); + +private: + NETSOCKET m_Socket; + + int64 m_Seed; + int64 m_PrevSeed; + + TOKEN m_GlobalToken; + TOKEN m_PrevGlobalToken; + + int m_SeedTime; + int64 m_NextSeedTime; +}; + + class CNetTokenCache { public: - void Init(NETSOCKET Socket, const TOKEN *pToken); + CNetTokenCache(); + ~CNetTokenCache(); + void Init(NETSOCKET Socket, const CNetTokenManager *pTokenManager); void SendPacketConnless(const NETADDR *pAddr, const void *pData, int DataSize); void FetchToken(const NETADDR *pAddr); - void ProcessTokenMessage(const NETADDR *pAddr, TOKEN PeerToken); + void AddToken(const NETADDR *pAddr, TOKEN PeerToken); TOKEN GetToken(const NETADDR *pAddr); void Update(); @@ -192,35 +222,7 @@ private: CConnlessPacketInfo *m_pConnlessPacketList; // TODO: enhance this, dynamic linked lists // are bad for performance NETSOCKET m_Socket; - const TOKEN *m_pToken; -}; - -class CNetTokenManager -{ -public: - void Init(NETSOCKET Socket, int SeedTime = NET_SEEDTIME); - void Update(); - - void GenerateSeed(); - - int ProcessMessage(const NETADDR *pAddr, const CNetPacketConstruct *pPacket, bool Notify); - - bool CheckToken(const NETADDR *pAddr, TOKEN Token, TOKEN ResponseToken, bool Notify); - TOKEN GenerateToken(const NETADDR *pAddr); - -private: - static TOKEN GenerateToken(const NETADDR *pAddr, int64 Seed); - - NETSOCKET m_Socket; - - int64 m_Seed; - int64 m_PrevSeed; - - TOKEN m_GlobalToken; - TOKEN m_PrevGlobalToken; - - int m_SeedTime; - int64 m_NextSeedTime; + const CNetTokenManager *m_pTokenManager; }; @@ -373,6 +375,7 @@ class CNetServer CNetRecvUnpacker m_RecvUnpacker; CNetTokenManager m_TokenManager; + CNetTokenCache m_TokenCache; int m_Flags; public: @@ -418,8 +421,6 @@ class CNetConsole CNetRecvUnpacker m_RecvUnpacker; - CNetTokenManager m_TokenManager; - public: void SetCallbacks(NETFUNC_NEWCLIENT pfnNewClient, NETFUNC_DELCLIENT pfnDelClient, void *pUser); @@ -448,7 +449,10 @@ class CNetClient { CNetConnection m_Connection; CNetRecvUnpacker m_RecvUnpacker; + + CNetTokenCache m_TokenCache; CNetTokenManager m_TokenManager; + NETSOCKET m_Socket; int m_Flags; public: diff --git a/src/engine/shared/network_client.cpp b/src/engine/shared/network_client.cpp index 8571e9f39..21afa74b8 100644 --- a/src/engine/shared/network_client.cpp +++ b/src/engine/shared/network_client.cpp @@ -19,6 +19,7 @@ bool CNetClient::Open(NETADDR BindAddr, int Flags) m_Connection.Init(m_Socket, false); m_TokenManager.Init(Socket); + m_TokenCache.Init(Socket, &m_TokenManager); m_Flags = Flags; @@ -81,6 +82,7 @@ int CNetClient::Recv(CNetChunk *pChunk, TOKEN *pResponseToken, int *pVersion) if(net_addr_comp(m_Connection.PeerAddress(), &Addr) == 0) { if(m_Connection.State() != NET_CONNSTATE_OFFLINE && m_Connection.State() != NET_CONNSTATE_ERROR && m_Connection.Feed(&m_RecvUnpacker.m_Data, &Addr)) + { if(!(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONNLESS)) m_RecvUnpacker.Start(&Addr, &m_Connection, 0); else @@ -90,34 +92,44 @@ int CNetClient::Recv(CNetChunk *pChunk, TOKEN *pResponseToken, int *pVersion) pChunk->m_pData = m_RecvUnpacker.m_Data.m_aChunkData; return 1; } - } - else if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONNLESS) - { - if(!(m_Flags&NETFLAG_ALLOWOLDSTYLE) && m_RecvUnpacker.m_Data.m_Version != NET_PACKETVERSION) - continue; - - int Accept = m_TokenManager.ProcessMessage(&Addr, &m_RecvUnpacker.m_Data, false); - if(!Accept) - continue; - - pChunk->m_Flags = NETSENDFLAG_CONNLESS; - - if(Accept < 0) - { - if(!(m_Flags&NETFLAG_ALLOWSTATELESS)) - continue; - pChunk->m_Flags |= NETSENDFLAG_STATELESS; } - pChunk->m_ClientID = -1; - pChunk->m_Address = Addr; - pChunk->m_DataSize = m_RecvUnpacker.m_Data.m_DataSize; - pChunk->m_pData = m_RecvUnpacker.m_Data.m_aChunkData; - if(pVersion) - *pVersion = m_RecvUnpacker.m_Data.m_Version; - if(pResponseToken) - *pResponseToken = m_RecvUnpacker.m_Data.m_ResponseToken; - return 1; + } + else + { + int Accept = m_TokenManager.ProcessMessage(&Addr, &m_RecvUnpacker.m_Data, false); + if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONTROL) + { + if(m_RecvUnpacker.m_Data.m_aChunkData[0] == NET_CTRLMSG_TOKEN) + m_TokenCache.AddToken(&Addr, m_RecvUnpacker.m_Data.m_ResponseToken); + } + else if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONNLESS) + { + if(!(m_Flags&NETFLAG_ALLOWOLDSTYLE) && m_RecvUnpacker.m_Data.m_Version != NET_PACKETVERSION) + continue; + + if(!Accept) + continue; + + pChunk->m_Flags = NETSENDFLAG_CONNLESS; + + if(Accept < 0) + { + if(!(m_Flags&NETFLAG_ALLOWSTATELESS)) + continue; + pChunk->m_Flags |= NETSENDFLAG_STATELESS; + } + pChunk->m_ClientID = -1; + pChunk->m_Address = Addr; + pChunk->m_DataSize = m_RecvUnpacker.m_Data.m_DataSize; + pChunk->m_pData = m_RecvUnpacker.m_Data.m_aChunkData; + + if(pVersion) + *pVersion = m_RecvUnpacker.m_Data.m_Version; + if(pResponseToken) + *pResponseToken = m_RecvUnpacker.m_Data.m_ResponseToken; + return 1; + } } } } @@ -134,13 +146,26 @@ int CNetClient::Send(CNetChunk *pChunk, TOKEN Token, int Version) return -1; } - // send connectionless packet - if(pChunk->m_ClientID == -1) + if(pChunk->m_Flags&NETSENDFLAG_STATELESS || Token != NET_TOKEN_NONE) + { + if(pChunk->m_Flags&NETSENDFLAG_STATELESS) + { + dbg_assert(pChunk->m_ClientID == -1, "errornous client id, connless packets can only be sent to cid=-1"); + dbg_assert(Token == NET_TOKEN_NONE, "stateless packets can't have a token"); + } CNetBase::SendPacketConnless(m_Socket, &pChunk->m_Address, Version, Token, m_TokenManager.GenerateToken(&pChunk->m_Address), pChunk->m_pData, pChunk->m_DataSize); + } else { - dbg_assert(pChunk->m_ClientID == 0, "errornous client id"); - m_Connection.SendPacketConnless((const char *)pChunk->m_pData, pChunk->m_DataSize); + if(pChunk->m_ClientID == -1) + { + m_TokenCache.SendPacketConnless(&pChunk->m_Address, pChunk->m_pData, pChunk->m_DataSize); + } + else + { + dbg_assert(pChunk->m_ClientID == 0, "errornous client id"); + m_Connection.SendPacketConnless((const char *)pChunk->m_pData, pChunk->m_DataSize); + } } } else @@ -188,3 +213,4 @@ const char *CNetClient::ErrorString() const { return m_Connection.ErrorString(); } + diff --git a/src/engine/shared/network_server.cpp b/src/engine/shared/network_server.cpp index cf1b02031..5ef147e96 100644 --- a/src/engine/shared/network_server.cpp +++ b/src/engine/shared/network_server.cpp @@ -19,6 +19,7 @@ bool CNetServer::Open(NETADDR BindAddr, CNetBan *pNetBan, int MaxClients, int Ma return false; m_TokenManager.Init(m_Socket); + m_TokenCache.Init(m_Socket, &m_TokenManager); m_pNetBan = pNetBan; @@ -90,6 +91,7 @@ int CNetServer::Update() } m_TokenManager.Update(); + m_TokenCache.Update(); return 0; } @@ -156,7 +158,59 @@ int CNetServer::Recv(CNetChunk *pChunk, TOKEN *pResponseToken, int *pVersion) if(!Accept) continue; - if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONNLESS) + if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONTROL) + { + if(m_RecvUnpacker.m_Data.m_aChunkData[0] == NET_CTRLMSG_CONNECT) + { + bool Found = false; + + // only allow a specific number of players with the same ip + NETADDR ThisAddr = Addr, OtherAddr; + int FoundAddr = 1; + ThisAddr.port = 0; + for(int i = 0; i < MaxClients(); i++) + { + if(m_aSlots[i].m_Connection.State() == NET_CONNSTATE_OFFLINE) + continue; + + OtherAddr = *m_aSlots[i].m_Connection.PeerAddress(); + OtherAddr.port = 0; + if(!net_addr_comp(&ThisAddr, &OtherAddr)) + { + if(FoundAddr++ >= m_MaxClientsPerIP) + { + char aBuf[128]; + str_format(aBuf, sizeof(aBuf), "Only %d players with the same IP are allowed", m_MaxClientsPerIP); + CNetBase::SendControlMsg(m_Socket, &Addr, m_RecvUnpacker.m_Data.m_Version, m_RecvUnpacker.m_Data.m_ResponseToken, 0, NET_CTRLMSG_CLOSE, aBuf, sizeof(aBuf)); + return 0; + } + } + } + + for(int i = 0; i < MaxClients(); i++) + { + if(m_aSlots[i].m_Connection.State() == NET_CONNSTATE_OFFLINE) + { + Found = true; + m_aSlots[i].m_Connection.SetToken(m_RecvUnpacker.m_Data.m_Token); + m_aSlots[i].m_Connection.Feed(&m_RecvUnpacker.m_Data, &Addr); + m_aSlots[i].m_Connection.SetToken(m_RecvUnpacker.m_Data.m_Token); // HACK! + if(m_pfnNewClient) + m_pfnNewClient(i, m_UserPtr); + break; + } + } + + if(!Found) + { + const char FullMsg[] = "This server is full"; + CNetBase::SendControlMsg(m_Socket, &Addr, m_RecvUnpacker.m_Data.m_Version, m_RecvUnpacker.m_Data.m_ResponseToken, 0, NET_CTRLMSG_CLOSE, FullMsg, sizeof(FullMsg)); + } + } + else if(m_RecvUnpacker.m_Data.m_aChunkData[0] == NET_CTRLMSG_TOKEN) + m_TokenCache.AddToken(&Addr, m_RecvUnpacker.m_Data.m_ResponseToken); + } + else if(m_RecvUnpacker.m_Data.m_Flags&NET_PACKETFLAG_CONNLESS) { if(!(m_Flags&NETFLAG_ALLOWOLDSTYLE) && m_RecvUnpacker.m_Data.m_Version != NET_PACKETVERSION) continue; @@ -243,14 +297,28 @@ int CNetServer::Send(CNetChunk *pChunk, TOKEN Token, int Version) return -1; } - // send connectionless packet - if(pChunk->m_ClientID == -1) + if(pChunk->m_Flags&NETSENDFLAG_STATELESS || Token != NET_TOKEN_NONE) + { + if(pChunk->m_Flags&NETSENDFLAG_STATELESS) + { + dbg_assert(pChunk->m_ClientID == -1, "errornous client id, connless packets can only be sent to cid=-1"); + dbg_assert(Token == NET_TOKEN_NONE, "stateless packets can't have a token"); + } CNetBase::SendPacketConnless(m_Socket, &pChunk->m_Address, Version, Token, m_TokenManager.GenerateToken(&pChunk->m_Address), pChunk->m_pData, pChunk->m_DataSize); + } else { - dbg_assert(pChunk->m_ClientID >= 0, "errornous client id"); - dbg_assert(pChunk->m_ClientID < MaxClients(), "errornous client id"); - m_aSlots[pChunk->m_ClientID].m_Connection.SendPacketConnless((const char *)pChunk->m_pData, pChunk->m_DataSize); + if(pChunk->m_ClientID == -1) + { + m_TokenCache.SendPacketConnless(&pChunk->m_Address, pChunk->m_pData, pChunk->m_DataSize); + } + else + { + dbg_assert(pChunk->m_ClientID >= 0, "errornous client id"); + dbg_assert(pChunk->m_ClientID < MaxClients(), "errornous client id"); + + m_aSlots[pChunk->m_ClientID].m_Connection.SendPacketConnless((const char *)pChunk->m_pData, pChunk->m_DataSize); + } } } else diff --git a/src/engine/shared/network_token.cpp b/src/engine/shared/network_token.cpp index bb7c7bb42..457d13763 100644 --- a/src/engine/shared/network_token.cpp +++ b/src/engine/shared/network_token.cpp @@ -36,7 +36,8 @@ void CNetTokenManager::Update() int CNetTokenManager::ProcessMessage(const NETADDR *pAddr, const CNetPacketConstruct *pPacket, bool Notify) { - if(pPacket->m_Token != NET_TOKEN_NONE && !CheckToken(pAddr, pPacket->m_Token, pPacket->m_ResponseToken, Notify)) + if(pPacket->m_Token != NET_TOKEN_NONE + && !CheckToken(pAddr, pPacket->m_Token, pPacket->m_ResponseToken, Notify)) return 0; // wrong token, silent ignore bool Verified = pPacket->m_Token != NET_TOKEN_NONE; @@ -57,10 +58,12 @@ int CNetTokenManager::ProcessMessage(const NETADDR *pAddr, const CNetPacketConst } if(Verified && TokenMessage) - return 0; // everything is fine, token exchange complete + return 1; // everything is fine, token exchange complete // client requesting token - CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, pPacket->m_ResponseToken, 0, NET_CTRLMSG_TOKEN, GenerateToken(pAddr)); + CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, + pPacket->m_ResponseToken, 0, NET_CTRLMSG_TOKEN, + GenerateToken(pAddr)); return 0; // no need to process NET_CTRLMSG_TOKEN further } @@ -83,7 +86,7 @@ void CNetTokenManager::GenerateSeed() m_NextSeedTime = time_get() + time_freq() * m_SeedTime; } -TOKEN CNetTokenManager::GenerateToken(const NETADDR *pAddr) +TOKEN CNetTokenManager::GenerateToken(const NETADDR *pAddr) const { return GenerateToken(pAddr, m_Seed); } @@ -91,13 +94,16 @@ TOKEN CNetTokenManager::GenerateToken(const NETADDR *pAddr) TOKEN CNetTokenManager::GenerateToken(const NETADDR *pAddr, int64 Seed) { static const NETADDR NullAddr = { 0 }; + NETADDR Addr; char aBuf[sizeof(NETADDR) + sizeof(int64)]; int Result; if(pAddr->type & NETTYPE_LINK_BROADCAST) return GenerateToken(&NullAddr, Seed); - mem_copy(aBuf, pAddr, sizeof(NETADDR)); + Addr = *pAddr; + Addr.port = 0; + mem_copy(aBuf, &Addr, sizeof(NETADDR)); mem_copy(aBuf + sizeof(NETADDR), &Seed, sizeof(int64)); Result = Hash(aBuf, sizeof(aBuf)); @@ -116,7 +122,9 @@ bool CNetTokenManager::CheckToken(const NETADDR *pAddr, TOKEN Token, TOKEN Respo if(GenerateToken(pAddr, m_PrevSeed) == Token) { if(Notify) - CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, ResponseToken, 0, NET_CTRLMSG_TOKEN, CurrentToken); // notify the peer about the new token + CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, + ResponseToken, 0, NET_CTRLMSG_TOKEN, CurrentToken); + // notify the peer about the new token return true; } else if(Token == m_GlobalToken) @@ -124,18 +132,41 @@ bool CNetTokenManager::CheckToken(const NETADDR *pAddr, TOKEN Token, TOKEN Respo else if(Token == m_PrevGlobalToken) { if(Notify) - CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, ResponseToken, 0, NET_CTRLMSG_TOKEN, m_GlobalToken); // notify the peer about the new token + CNetBase::SendControlMsgWithToken(m_Socket, (NETADDR *)pAddr, + ResponseToken, 0, NET_CTRLMSG_TOKEN, m_GlobalToken); + // notify the peer about the new token return true; } return false; } -void CNetTokenCache::Init(NETSOCKET Socket, const TOKEN *pToken) + +CNetTokenCache::CNetTokenCache() { + m_pTokenManager = 0; + m_pConnlessPacketList = 0; +} + +CNetTokenCache::~CNetTokenCache() +{ + // delete the linked list + while(m_pConnlessPacketList) + { + CConnlessPacketInfo *pTemp = m_pConnlessPacketList->m_pNext; + delete m_pConnlessPacketList; + m_pConnlessPacketList = pTemp; + } +} + +void CNetTokenCache::Init(NETSOCKET Socket, const CNetTokenManager *pTokenManager) +{ + // call the destructor to clear the linked list + this->~CNetTokenCache(); + m_TokenCache.Init(); m_Socket = Socket; - m_pToken = pToken; + m_pTokenManager = pTokenManager; } void CNetTokenCache::SendPacketConnless(const NETADDR *pAddr, const void *pData, int DataSize) @@ -143,7 +174,8 @@ void CNetTokenCache::SendPacketConnless(const NETADDR *pAddr, const void *pData, TOKEN Token = GetToken(pAddr); if(Token != NET_TOKEN_NONE) { - CNetBase::SendPacketConnless(m_Socket, pAddr, NET_PACKETVERSION, Token, *m_pToken, pData, DataSize); + CNetBase::SendPacketConnless(m_Socket, pAddr, NET_PACKETVERSION, + Token, m_pTokenManager->GenerateToken(pAddr), pData, DataSize); } else { @@ -155,6 +187,7 @@ void CNetTokenCache::SendPacketConnless(const NETADDR *pAddr, const void *pData, ppInfo = &(*ppInfo)->m_pNext; *ppInfo = new CConnlessPacketInfo(); mem_copy((*ppInfo)->m_aData, pData, DataSize); + (*ppInfo)->m_Addr = *pAddr; (*ppInfo)->m_DataSize = DataSize; (*ppInfo)->m_Expiry = time_get() + time_freq() * NET_TOKENCACHE_PACKETEXPIRY; } @@ -174,13 +207,15 @@ TOKEN CNetTokenCache::GetToken(const NETADDR *pAddr) void CNetTokenCache::FetchToken(const NETADDR *pAddr) { - CNetBase::SendControlMsgWithToken(m_Socket, pAddr, NET_TOKEN_NONE, 0, NET_CTRLMSG_TOKEN, *m_pToken); + CNetBase::SendControlMsgWithToken(m_Socket, pAddr, NET_TOKEN_NONE, 0, + NET_CTRLMSG_TOKEN, m_pTokenManager->GenerateToken(pAddr)); } -void CNetTokenCache::ProcessTokenMessage(const NETADDR *pAddr, TOKEN Token) +void CNetTokenCache::AddToken(const NETADDR *pAddr, TOKEN Token) { if(Token == NET_TOKEN_NONE) return; + CAddressInfo Info; Info.m_Addr = *pAddr; Info.m_Token = Token; @@ -197,7 +232,10 @@ void CNetTokenCache::ProcessTokenMessage(const NETADDR *pAddr, TOKEN Token) { if(net_addr_comp(&pInfo->m_Addr, pAddr) == 0) { - CNetBase::SendPacketConnless(m_Socket, pAddr, NET_PACKETVERSION, Token, *m_pToken, pInfo->m_aData, pInfo->m_DataSize); + CNetBase::SendPacketConnless(m_Socket, pAddr, + NET_PACKETVERSION, Token, + m_pTokenManager->GenerateToken(pAddr), + pInfo->m_aData, pInfo->m_DataSize); *ppPrevNext = pInfo->m_pNext; delete pInfo; pInfo = *ppPrevNext; @@ -218,12 +256,12 @@ void CNetTokenCache::Update() // drop expired packets - CConnlessPacketInfo *pInfo = m_pConnlessPacketList; - while(pInfo && pInfo->m_Expiry <= Now) + while(m_pConnlessPacketList && m_pConnlessPacketList->m_Expiry <= Now) { - m_pConnlessPacketList = pInfo->m_pNext; - delete pInfo; - pInfo = m_pConnlessPacketList; + CConnlessPacketInfo *pNewList; + pNewList = m_pConnlessPacketList->m_pNext; + delete m_pConnlessPacketList; + m_pConnlessPacketList = pNewList; } } diff --git a/src/mastersrv/mastersrv.cpp b/src/mastersrv/mastersrv.cpp index 60feba96f..482471b79 100644 --- a/src/mastersrv/mastersrv.cpp +++ b/src/mastersrv/mastersrv.cpp @@ -29,6 +29,7 @@ struct CCheckServer NETADDR m_AltAddress; int m_TryCount; int64 m_TryTime; + TOKEN m_Token; }; static CCheckServer m_aCheckServers[MAX_SERVERS]; @@ -168,7 +169,7 @@ void BuildPackets() } } -void SendOk(NETADDR *pAddr) +void SendOk(NETADDR *pAddr, TOKEN Token) { CNetChunk p; p.m_ClientID = -1; @@ -178,11 +179,11 @@ void SendOk(NETADDR *pAddr) p.m_pData = SERVERBROWSE_FWOK; // send on both to be sure - m_NetChecker.Send(&p); - m_NetOp.Send(&p); + m_NetChecker.Send(&p, Token); + m_NetOp.Send(&p, Token); } -void SendError(NETADDR *pAddr) +void SendError(NETADDR *pAddr, TOKEN Token) { CNetChunk p; p.m_ClientID = -1; @@ -190,10 +191,10 @@ void SendError(NETADDR *pAddr) p.m_Flags = NETSENDFLAG_CONNLESS; p.m_DataSize = sizeof(SERVERBROWSE_FWERROR); p.m_pData = SERVERBROWSE_FWERROR; - m_NetOp.Send(&p); + m_NetOp.Send(&p, Token); } -void SendCheck(NETADDR *pAddr) +void SendCheck(NETADDR *pAddr, TOKEN Token) { CNetChunk p; p.m_ClientID = -1; @@ -201,10 +202,10 @@ void SendCheck(NETADDR *pAddr) p.m_Flags = NETSENDFLAG_CONNLESS; p.m_DataSize = sizeof(SERVERBROWSE_FWCHECK); p.m_pData = SERVERBROWSE_FWCHECK; - m_NetChecker.Send(&p); + m_NetChecker.Send(&p, Token); } -void AddCheckserver(NETADDR *pInfo, NETADDR *pAlt, ServerType Type) +void AddCheckserver(NETADDR *pInfo, NETADDR *pAlt, ServerType Type, TOKEN Token) { // add server if(m_NumCheckServers == MAX_SERVERS) @@ -223,6 +224,7 @@ void AddCheckserver(NETADDR *pInfo, NETADDR *pAlt, ServerType Type) m_aCheckServers[m_NumCheckServers].m_TryCount = 0; m_aCheckServers[m_NumCheckServers].m_TryTime = 0; m_aCheckServers[m_NumCheckServers].m_Type = Type; + m_aCheckServers[m_NumCheckServers].m_Token = Token; m_NumCheckServers++; } @@ -274,7 +276,7 @@ void UpdateServers() dbg_msg("mastersrv", "check failed: %s (%s)", aAddrStr, aAltAddrStr); // FAIL!! - SendError(&m_aCheckServers[i].m_Address); + SendError(&m_aCheckServers[i].m_Address, m_aCheckServers[i].m_Token); m_aCheckServers[i] = m_aCheckServers[m_NumCheckServers-1]; m_NumCheckServers--; i--; @@ -284,9 +286,9 @@ void UpdateServers() m_aCheckServers[i].m_TryCount++; m_aCheckServers[i].m_TryTime = Now; if(m_aCheckServers[i].m_TryCount&1) - SendCheck(&m_aCheckServers[i].m_Address); + SendCheck(&m_aCheckServers[i].m_Address, m_aCheckServers[i].m_Token); else - SendCheck(&m_aCheckServers[i].m_AltAddress); + SendCheck(&m_aCheckServers[i].m_AltAddress, m_aCheckServers[i].m_Token); } } } @@ -361,13 +363,13 @@ int main(int argc, const char **argv) // ignore_convention BindAddr.port = MASTERSERVER_PORT; } - if(!m_NetOp.Open(BindAddr, 0)) + if(!m_NetOp.Open(BindAddr, NETFLAG_ALLOWSTATELESS)) { dbg_msg("mastersrv", "couldn't start network (op)"); return -1; } BindAddr.port = MASTERSERVER_PORT+1; - if(!m_NetChecker.Open(BindAddr, 0)) + if(!m_NetChecker.Open(BindAddr, NETFLAG_ALLOWSTATELESS)) { dbg_msg("mastersrv", "couldn't start network (checker)"); return -1; @@ -385,7 +387,8 @@ int main(int argc, const char **argv) // ignore_convention // process m_aPackets CNetChunk Packet; - while(m_NetOp.Recv(&Packet)) + TOKEN Token; + while(m_NetOp.Recv(&Packet, &Token)) { // check if the server is banned if(m_NetBan.IsBanned(&Packet.m_Address, 0, 0)) @@ -402,9 +405,9 @@ int main(int argc, const char **argv) // ignore_convention d[sizeof(SERVERBROWSE_HEARTBEAT)+1]; // add it - AddCheckserver(&Packet.m_Address, &Alt, SERVERTYPE_NORMAL); + AddCheckserver(&Packet.m_Address, &Alt, SERVERTYPE_NORMAL, Token); } - else if(Packet.m_DataSize == sizeof(SERVERBROWSE_HEARTBEAT_LEGACY)+2 && + /*else if(Packet.m_DataSize == sizeof(SERVERBROWSE_HEARTBEAT_LEGACY)+2 && mem_comp(Packet.m_pData, SERVERBROWSE_HEARTBEAT_LEGACY, sizeof(SERVERBROWSE_HEARTBEAT_LEGACY)) == 0) { NETADDR Alt; @@ -416,8 +419,7 @@ int main(int argc, const char **argv) // ignore_convention // add it AddCheckserver(&Packet.m_Address, &Alt, SERVERTYPE_LEGACY); - } - + }*/ else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETCOUNT) && mem_comp(Packet.m_pData, SERVERBROWSE_GETCOUNT, sizeof(SERVERBROWSE_GETCOUNT)) == 0) { @@ -431,9 +433,9 @@ int main(int argc, const char **argv) // ignore_convention p.m_pData = &m_CountData; m_CountData.m_High = (m_NumServers>>8)&0xff; m_CountData.m_Low = m_NumServers&0xff; - m_NetOp.Send(&p); + m_NetOp.Send(&p, Token); } - else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETCOUNT_LEGACY) && + /*else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETCOUNT_LEGACY) && mem_comp(Packet.m_pData, SERVERBROWSE_GETCOUNT_LEGACY, sizeof(SERVERBROWSE_GETCOUNT_LEGACY)) == 0) { dbg_msg("mastersrv", "count requested, responding with %d", m_NumServers); @@ -447,7 +449,7 @@ int main(int argc, const char **argv) // ignore_convention m_CountDataLegacy.m_High = (m_NumServers>>8)&0xff; m_CountDataLegacy.m_Low = m_NumServers&0xff; m_NetOp.Send(&p); - } + }*/ else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETLIST) && mem_comp(Packet.m_pData, SERVERBROWSE_GETLIST, sizeof(SERVERBROWSE_GETLIST)) == 0) { @@ -463,10 +465,10 @@ int main(int argc, const char **argv) // ignore_convention { p.m_DataSize = m_aPackets[i].m_Size; p.m_pData = &m_aPackets[i].m_Data; - m_NetOp.Send(&p); + m_NetOp.Send(&p, Token); } } - else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETLIST_LEGACY) && + /*else if(Packet.m_DataSize == sizeof(SERVERBROWSE_GETLIST_LEGACY) && mem_comp(Packet.m_pData, SERVERBROWSE_GETLIST_LEGACY, sizeof(SERVERBROWSE_GETLIST_LEGACY)) == 0) { // someone requested the list @@ -483,11 +485,11 @@ int main(int argc, const char **argv) // ignore_convention p.m_pData = &m_aPacketsLegacy[i].m_Data; m_NetOp.Send(&p); } - } + }*/ } // process m_aPackets - while(m_NetChecker.Recv(&Packet)) + while(m_NetChecker.Recv(&Packet, &Token)) { // check if the server is banned if(m_NetBan.IsBanned(&Packet.m_Address, 0, 0)) @@ -515,7 +517,7 @@ int main(int argc, const char **argv) // ignore_convention continue; AddServer(&Packet.m_Address, Type); - SendOk(&Packet.m_Address); + SendOk(&Packet.m_Address, Token); } }