/*
* Copyright (c) 2023, The PurpleI2P Project
*
* This file is part of Purple i2pd project and licensed under BSD3
*
* See full license text in LICENSE file at top of project tree
*
*/

#include "KadDHT.h"

// NOTE(review): the reviewed text of this file had lost its template argument
// lists ("std::shared_ptr&", "std::vector >"). They are restored below as
// std::shared_ptr<RouterInfo>, inferred from how the pointers are used here
// (GetIdentHash ()); confirm against the declarations in KadDHT.h.

namespace i2p
{
namespace data
{
	// Creates a node with no children; the router pointer starts out empty.
	DHTNode::DHTNode ():
		zero (nullptr), one (nullptr)
	{
	}

	// Destroys both children, and through their destructors the whole subtree.
	DHTNode::~DHTNode ()
	{
		// delete on a null pointer is a no-op, so no checks are required
		delete zero;
		delete one;
	}

	// Takes the router held by one child (the 'one' child if fromOne is true,
	// the 'zero' child otherwise) into this node and destroys that child.
	// Does nothing if the selected child doesn't exist.
	void DHTNode::MoveRouterUp (bool fromOne)
	{
		DHTNode *& side = fromOne ? one : zero;
		if (side)
		{
			// a router already stored here (shouldn't happen) is simply replaced
			router = side->router;
			side->router = nullptr;
			delete side; // also destroys anything still hanging under the child
			side = nullptr;
		}
	}

	// Creates an empty table: zero routers and a single empty root node.
	DHTTable::DHTTable ():
		m_Size (0)
	{
		m_Root = new DHTNode;
	}

	DHTTable::~DHTTable ()
	{
		delete m_Root;
	}

	// Drops every router and starts over with a fresh empty root.
	void DHTTable::Clear ()
	{
		m_Size = 0;
		delete m_Root;
		m_Root = new DHTNode;
	}

	// Adds r to the table, or replaces the stored router with the same ident hash.
	// A null pointer is ignored.
	void DHTTable::Insert (const std::shared_ptr<RouterInfo>& r)
	{
		if (!r) return;
		Insert (r, m_Root, 0);
	}

	// Recursive part of Insert.
	// root  - subtree to insert into, must not be null
	// level - index of the ident hash bit that selects a child of root
	void DHTTable::Insert (const std::shared_ptr<RouterInfo>& r, DHTNode * root, int level)
	{
		if (root->router)
		{
			if (root->router->GetIdentHash () == r->GetIdentHash ())
			{
				root->router = r; // replace
				return;
			}
			// the node is occupied by another router: take it out and put both
			// routers back under the first bit where their hashes differ
			auto r2 = root->router;
			root->router = nullptr;
			m_Size--; // r2 is counted again when it's re-inserted below
			int bit1, bit2;
			do
			{
				bit1 = r->GetIdentHash ().GetBit (level);
				bit2 = r2->GetIdentHash ().GetBit (level);
				if (bit1 == bit2)
				{
					// same bit at this level: extend the common path by one node.
					// A node that held a router is not expected to have children;
					// if it does, give up (r2 has already been taken out by then)
					if (bit1)
					{
						if (root->one) return; // something wrong
						root->one = new DHTNode;
						root = root->one;
					}
					else
					{
						if (root->zero) return; // something wrong
						root->zero = new DHTNode;
						root = root->zero;
					}
					level++;
				}
			}
			while (bit1 == bit2);

			// bits differ at 'level': each router goes to its own child
			if (!root->zero) root->zero = new DHTNode;
			if (!root->one) root->one = new DHTNode;
			if (bit1)
			{
				Insert (r2, root->zero, level + 1);
				Insert (r, root->one, level + 1);
			}
			else
			{
				Insert (r2, root->one, level + 1);
				Insert (r, root->zero, level + 1);
			}
		}
		else
		{
			if (!root->zero && !root->one)
			{
				// empty leaf: store the router right here
				root->router = r;
				m_Size++;
				return;
			}
			// inner node: descend by the bit at this level, creating the child if needed
			int bit = r->GetIdentHash ().GetBit (level);
			if (bit)
			{
				if (!root->one) root->one = new DHTNode;
				Insert (r, root->one, level + 1);
			}
			else
			{
				if (!root->zero) root->zero = new DHTNode;
				Insert (r, root->zero, level + 1);
			}
		}
	}

	// Removes the router with ident hash h. Returns true if it was found.
	bool DHTTable::Remove (const IdentHash& h)
	{
		return Remove (h, m_Root, 0);
	}

	// Recursive part of Remove.
	// root  - subtree to search, may be null
	// level - index of the bit of h that selects a child of root
	bool DHTTable::Remove (const IdentHash& h, DHTNode * root, int level)
	{
		if (!root) return false;
		if (root->router && root->router->GetIdentHash () == h)
		{
			root->router = nullptr;
			m_Size--;
			return true;
		}

		// 'side' is the child selected by the bit of h, 'other' is its sibling
		const bool isOne = h.GetBit (level) != 0;
		DHTNode *& side = isOne ? root->one : root->zero;
		DHTNode *& other = isOne ? root->zero : root->one;
		if (side && Remove (h, side, level + 1))
		{
			if (side->IsEmpty ())
			{
				// nothing left on this side: drop it and, if the sibling only
				// holds a router, pull that router up into this node
				delete side;
				side = nullptr;
				if (other && other->router)
					root->MoveRouterUp (!isOne);
			}
			else if (side->router && !other)
				root->MoveRouterUp (isOne); // the only child holds a router: pull it up
			return true;
		}
		return false;
	}

	// Returns a router close to h that is accepted by filter (any router if the
	// filter is empty), or nullptr if nothing suitable is found.
	// The filter is kept in m_Filter while the lookup is running.
	std::shared_ptr<RouterInfo> DHTTable::FindClosest (const IdentHash& h, const Filter& filter) const
	{
		if (filter) m_Filter = filter;
		auto r = FindClosest (h, m_Root, 0);
		m_Filter = nullptr;
		return r;
	}

	// Recursive part of the single-result FindClosest.
	// root must not be null; level is the bit index that corresponds to root.
	std::shared_ptr<RouterInfo> DHTTable::FindClosest (const IdentHash& h, DHTNode * root, int level) const
	{
		// follow single-child chains until a router or a two-way split is reached
		bool split = false;
		do
		{
			if (root->router)
				return (!m_Filter || m_Filter (root->router)) ? root->router : nullptr;
			split = root->zero && root->one;
			if (!split)
			{
				if (root->zero) root = root->zero;
				else if (root->one) root = root->one;
				else return nullptr;
				level++;
			}
		}
		while (!split);

		// try the child matching the bit of h first, then the other one
		int bit = h.GetBit (level);
		if (bit)
		{
			if (root->one)
			{
				auto r = FindClosest (h, root->one, level + 1);
				if (r) return r;
			}
			if (root->zero)
			{
				auto r = FindClosest (h, root->zero, level + 1);
				if (r) return r;
			}
		}
		else
		{
			if (root->zero)
			{
				auto r = FindClosest (h, root->zero, level + 1);
				if (r) return r;
			}
			if (root->one)
			{
				auto r = FindClosest (h, root->one, level + 1);
				if (r) return r;
			}
		}
		return nullptr;
	}

	// Returns up to num routers close to h that are accepted by filter (any
	// router if the filter is empty). The result is empty if num is 0.
	// The filter is kept in m_Filter while the lookup is running.
	std::vector<std::shared_ptr<RouterInfo> > DHTTable::FindClosest (const IdentHash& h, size_t num, const Filter& filter) const
	{
		std::vector<std::shared_ptr<RouterInfo> > vec;
		if (num > 0)
		{
			if (filter) m_Filter = filter;
			FindClosest (h, num, m_Root, 0, vec);
			m_Filter = nullptr;
		}
		return vec;
	}

	// Recursive part of the multi-result FindClosest: appends accepted routers
	// to hashes and stops descending once it contains num entries.
	// root must not be null; level is the bit index that corresponds to root.
	void DHTTable::FindClosest (const IdentHash& h, size_t num, DHTNode * root, int level,
		std::vector<std::shared_ptr<RouterInfo> >& hashes) const
	{
		if (hashes.size () >= num) return;
		// follow single-child chains until a router or a two-way split is reached
		bool split = false;
		do
		{
			if (root->router)
			{
				if (!m_Filter || m_Filter (root->router))
					hashes.push_back (root->router);
				return;
			}
			split = root->zero && root->one;
			if (!split)
			{
				if (root->zero) root = root->zero;
				else if (root->one) root = root->one;
				else return;
				level++;
			}
		}
		while (!split);

		// collect from the child matching the bit of h first, then from the other one
		int bit = h.GetBit (level);
		if (bit)
		{
			if (root->one)
				FindClosest (h, num, root->one, level + 1, hashes);
			if (hashes.size () < num && root->zero)
				FindClosest (h, num, root->zero, level + 1, hashes);
		}
		else
		{
			if (root->zero)
				FindClosest (h, num, root->zero, level + 1, hashes);
			if (hashes.size () < num && root->one)
				FindClosest (h, num, root->one, level + 1, hashes);
		}
	}

	// Keeps only the routers accepted by filter. An empty filter clears the table.
	void DHTTable::Cleanup (const Filter& filter)
	{
		if (filter)
		{
			m_Filter = filter;
			Cleanup (m_Root);
			m_Filter = nullptr;
		}
		else
			Clear ();
	}

	// Recursive part of Cleanup: drops routers rejected by m_Filter from the
	// subtree, deletes children that became empty and pulls a router up from a
	// child that is left as the only one.
	void DHTTable::Cleanup (DHTNode * root)
	{
		if (!root) return;
		if (root->router)
		{
			if (!m_Filter || !m_Filter (root->router))
			{
				m_Size--;
				root->router = nullptr;
			}
			return;
		}
		if (root->zero)
		{
			Cleanup (root->zero);
			if (root->zero->IsEmpty ())
			{
				delete root->zero;
				root->zero = nullptr;
			}
			else if (root->zero->router && !root->one)
				root->MoveRouterUp (false); // same as for the 'one' child below and as Remove does
		}
		if (root->one)
		{
			Cleanup (root->one);
			if (root->one->IsEmpty ())
			{
				delete root->one;
				root->one = nullptr;
				if (root->zero && root->zero->router)
					root->MoveRouterUp (false);
			}
			else if (root->one->router && !root->zero)
				root->MoveRouterUp (true);
		}
	}

	// Writes a text dump of the whole tree to s.
	void DHTTable::Print (std::stringstream& s)
	{
		Print (s, m_Root, 0);
	}

	// Recursive part of Print: the node is prefixed with 'level' dashes; a router
	// is shown as '>' followed by its abbreviated ident hash, or as "error" if
	// the node holding it also has children.
	void DHTTable::Print (std::stringstream& s, DHTNode * root, int level)
	{
		if (!root) return;
		s << std::string (level, '-');
		if (root->router)
		{
			if (!root->zero && !root->one)
				s << '>' << GetIdentHashAbbreviation (root->router->GetIdentHash ());
			else
				s << "error";
		}
		s << '\n';
		if (root->zero)
		{
			s << std::string (level, '-') << "0" << '\n';
			Print (s, root->zero, level + 1);
		}
		if (root->one)
		{
			s << std::string (level, '-') << "1" << '\n';
			Print (s, root->one, level + 1);
		}
	}
}
}