// Copyright 2025 Global Phasing Ltd.
//
// AcedrgTables - COD/CSD-based atom classification and restraint value lookup
// Port of AceDRG codClassify system to gemmi.

#ifndef GEMMI_ACEDRG_TABLES_HPP_
#define GEMMI_ACEDRG_TABLES_HPP_

#include <string>
#include <tuple>
#include <vector>
#include <array>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <cmath>
#include <algorithm>
#include "chemcomp.hpp"
#include "elem.hpp"
#include "fail.hpp"

namespace gemmi {

// Orbital hybridization assigned to an atom during COD-style classification.
// The underlying type and enumerator values are spelled out explicitly; they
// are exactly the implicit ones (int, counting from 0).
enum class Hybridization : int {
  SP1 = 0,    // sp: linear
  SP2 = 1,    // sp2: trigonal planar
  SP3 = 2,    // sp3: tetrahedral
  SPD5 = 3,   // d-orbital participation, 5-coordinate
  SPD6 = 4,   // d-orbital participation, 6-coordinate
  SPD7 = 5,   // d-orbital participation, 7-coordinate
  SPD8 = 6,   // d-orbital participation, 8-coordinate
  SP_NON = 7  // non-standard or not determined
};

// Textual conversions for Hybridization (implemented out of line).
// NOTE(review): the exact spelling produced/accepted and the result for an
// unrecognised string are defined in the implementation file - confirm there.
GEMMI_DLL const char* hybridization_to_string(Hybridization h);
GEMMI_DLL Hybridization hybridization_from_string(const std::string& text);

// Idealised coordination polyhedra used for metal-centred angle restraints.
// Comments give the coordination number (CN) and characteristic angles.
// Values are written out explicitly and match the implicit numbering.
enum class CoordGeometry : int {
  LINEAR = 0,                  // CN=2: 180 deg
  TRIGONAL_PLANAR = 1,         // CN=3: 120 deg
  T_SHAPED = 2,                // CN=3: 90, 180 deg
  TETRAHEDRAL = 3,             // CN=4: 109.47 deg
  SQUARE_PLANAR = 4,           // CN=4: 90, 180 deg
  TRIGONAL_BIPYRAMIDAL = 5,    // CN=5: 90, 120, 180 deg
  SQUARE_PYRAMIDAL = 6,        // CN=5: 90, 180 deg
  OCTAHEDRAL = 7,              // CN=6: 90 deg
  TRIGONAL_PRISM = 8,          // CN=6: alternative to octahedral
  PENTAGONAL_BIPYRAMIDAL = 9,  // CN=7
  CAPPED_OCTAHEDRAL = 10,      // CN=7
  SQUARE_ANTIPRISM = 11,       // CN=8
  UNKNOWN = 12
};

// Per-atom classification record; AcedrgTables::classify_atoms() returns one
// of these for every atom of a ChemComp. Scalar fields carry in-class
// defaults (index starts at -1, element X, hybridization SP_NON, all counters,
// flags and charges zero); strings and containers start empty.
struct CodAtomInfo {
  int index = -1;                      // position in ChemComp.atoms
  std::string id;                      // atom name
  int hashing_value = 0;               // 0-1000+ hash code
  Element el = El::X;                  // chemical element
  Hybridization hybrid = Hybridization::SP_NON;
  std::string cod_class;               // full COD class, e.g. "C[6a](C[6a]C[6a])(C[6a])(H)"
  std::string cod_class_no_charge;     // COD class built without formal charges (used for COD table lookup)
  std::string cod_main;                // AceDRG codAtmMain
  std::string cod_root;                // AceDRG codAtmRoot
  std::string nb_symb;                 // AceDRG codNBSymb
  std::string nb2_symb;                // AceDRG codNB2Symb
  std::string nb3_symb;                // AceDRG codNB3Symb
  std::string nb1nb2_sp;               // AceDRG codNB1NB2_SP
  std::vector<int> conn_atoms_no_metal;  // indices of non-metal neighbours
  int connectivity = 0;                // number of bonded atoms
  int metal_connectivity = 0;          // number of metal neighbours
  int min_ring_size = 0;               // smallest ring size; 0 = not in a ring
  bool is_aromatic = false;            // member of an aromatic ring
  bool is_metal = false;               // atom is a metal
  int excess_electrons = 0;            // formal charge / lone-pair bookkeeping
  float charge = 0.0f;                 // formal/partial charge
  float par_charge = 0.0f;             // partial charge (AceDRG parCharge)
  int bonding_idx = 0;                 // AceDRG bonding index (1=sp1, 2=sp2, 3=sp3, ...)

  // AceDRG-style ring bookkeeping.
  std::map<std::string, int> ring_rep;
  std::map<std::string, std::string> ring_rep_s;
  std::vector<int> in_rings;

  CodAtomInfo() = default;
};

// One restraint statistic: a mean value, its spread, the number of
// observations behind it and the specificity level of the match that
// produced it. value and sigma default to NaN; count and level default to 0.
struct CodStats {
  double value{NAN};
  double sigma{NAN};
  int count{0};
  // Match specificity: 0 = none, 1-4 = aggregated, 10 = full match.
  // NOTE(review): AcedrgTables::fill_bond describes levels as
  // "10=full, 4+=neighbor matched, 0-3=aggregated"; the two disagree on 4.
  int level{0};

  CodStats() = default;
  CodStats(double val, double sig, int cnt, int lvl = 0)
      : value{val}, sigma{sig}, count{cnt}, level{lvl} {}
};

// X-H distance for a protonated hydrogen, given twice: to the electron-cloud
// centre and to the nucleus. Each value/sigma pair defaults to NaN.
struct ProtHydrDist {
  double electron_val{NAN};
  double electron_sigma{NAN};
  double nucleus_val{NAN};
  double nucleus_sigma{NAN};
};

// One row of the metal-ligand bond table.
// Two value/sigma/count triples are stored: a "pre_" set and a plain set.
// NOTE(review): which one the lookup prefers is decided in the
// implementation - confirm before relying on either.
struct MetalBondEntry {
  // Keys: element pair, their coordination numbers, and ligand class.
  Element metal{El::X};
  Element ligand{El::X};
  int metal_coord{0};
  int ligand_coord{0};
  std::string ligand_class;
  // "pre_" statistics.
  double pre_value{NAN};
  double pre_sigma{NAN};
  int pre_count{0};
  // Main statistics.
  double value{NAN};
  double sigma{NAN};
  int count{0};
};

// One ideal angle (with sigma) for a metal with a given coordination number
// and geometry. angle/sigma default to NaN, geometry to UNKNOWN.
struct MetalAngleEntry {
  Element metal{El::X};
  int coord_number{0};
  CoordGeometry geometry{CoordGeometry::UNKNOWN};
  double angle{NAN};
  double sigma{NAN};
};

// Maps (metal element, coordination number) to a specific coordination
// geometry. NOTE(review): presumably overrides a default geometry choice -
// confirm against its use in the implementation.
struct MetalCoordOverride {
  Element metal{El::X};
  int coord_number{0};
  CoordGeometry geometry{CoordGeometry::UNKNOWN};
};

// A tabulated torsion restraint: target value, sigma, periodicity, a
// priority and an identifier. Numeric fields default to zero, id to empty.
struct TorsionEntry {
  double value{0.0};
  double sigma{0.0};
  int period{0};
  int priority{0};
  std::string id;
};

// ============================================================================
// Main AcedrgTables class
// ============================================================================

struct GEMMI_DLL AcedrgTables {
  AcedrgTables() = default;

  // Load all tables from directory.
  // NOTE(review): `skip_angles` presumably skips loading the angle tables
  // (angle_* members below) - confirm in the implementation.
  void load_tables(const std::string& tables_dir, bool skip_angles = false);

  // Process a ChemComp - fill all missing restraint values
  void fill_restraints(ChemComp& cc) const;

  // Assign CCP4 atom energy types (type_energy) following AceDRG rules
  void assign_ccp4_types(ChemComp& cc) const;

  // Torsion lookups by the four atom names a1-a2-a3-a4. The peptide table
  // yields a single entry and the nucleotide table a list of entries.
  // NOTE(review): the bool result presumably reports whether `out` was
  // filled - confirm in the implementation.
  bool lookup_pep_tors(const std::string& a1, const std::string& a2,
                       const std::string& a3, const std::string& a4,
                       TorsionEntry& out) const;
  bool lookup_nucl_tors(const std::string& a1, const std::string& a2,
                        const std::string& a3, const std::string& a4,
                        std::vector<TorsionEntry>& out) const;

  // Individual lookups - returns match level (10=full, 4+=neighbor matched, 0-3=aggregated)
  // NOTE(review): CodStats::level describes 1-4 as "aggregated"; the two
  // descriptions disagree about level 4 - reconcile.
  int fill_bond(const ChemComp& cc,
                const std::vector<CodAtomInfo>& atom_info,
                Restraints::Bond& bond) const;
  int fill_angle(const ChemComp& cc,
                 const std::vector<CodAtomInfo>& atom_info,
                 Restraints::Angle& angle,
                 const std::set<int>& needed_files) const;

  // Atom classification - returns info for all atoms
  std::vector<CodAtomInfo> classify_atoms(const ChemComp& cc) const;

  // Compute acedrg_type string (like acedrg --typeOut)
  // Format: CentralElement(Neighbor1_desc)(Neighbor2_desc)...
  // where each neighbor description = neighbor element + sorted neighbor's other neighbors
  std::string compute_acedrg_type(const CodAtomInfo& atom,
                                  const std::vector<CodAtomInfo>& atoms,
                                  const std::vector<std::vector<int>>& neighbors) const;
  std::vector<std::string> compute_acedrg_types(const ChemComp& cc) const;

  // Configuration.
  // Sigma bounds are applied by clamp_bond_sigma()/clamp_angle_sigma().
  double upper_bond_sigma = 0.02;
  double lower_bond_sigma = 0.01;
  double upper_angle_sigma = 3.0;
  double lower_angle_sigma = 1.5;
  int min_observations_angle = 3;  // AceDRG default for angles
  int min_observations_angle_fallback = 3;
  int min_observations_bond = 3;   // AceDRG default for bonds (aNumTh=3)
  int metal_class_min_count = 5; // AceDRG uses >5 for metal class selection
  int verbose = 0;  // Debug output level (0=off, 1=basic, 2=detailed)

  static constexpr int HASH_SIZE = 1000;
  // Table directory
  std::string tables_dir_;
  bool tables_loaded_ = false;

  // Hash code tables
  std::map<int, std::string> digit_keys_;  // hash -> footprint
  std::map<int, int> linked_hash_;         // hash -> linked hash

  // HRS (High-Resolution Summary) bond tables
  // Key: hash1, hash2, hybrid_pair, in_ring (ordered lexicographically)
  struct BondHRSKey {
    int hash1, hash2;
    std::string hybrid_pair;
    std::string in_ring;
    bool operator<(const BondHRSKey& o) const {
      return std::tie(hash1, hash2, hybrid_pair, in_ring)
           < std::tie(o.hash1, o.hash2, o.hybrid_pair, o.in_ring);
    }
  };
  std::map<BondHRSKey, CodStats> bond_hrs_;

  // HRS angle tables
  // Key: hash1, hash2, hash3, value_key (ring:hybr_tuple), ordered lexicographically
  struct AngleHRSKey {
    int hash1, hash2, hash3;
    std::string value_key;
    bool operator<(const AngleHRSKey& o) const {
      return std::tie(hash1, hash2, hash3, value_key)
           < std::tie(o.hash1, o.hash2, o.hash3, o.value_key);
    }
  };
  std::map<AngleHRSKey, CodStats> angle_hrs_;

  struct Ccp4BondEntry {
    double length = NAN;
    double sigma = NAN;
  };

  // CCP4 bond table: three nested string keys -> bond length/sigma.
  // NOTE(review): the keys are presumably type1 -> type2 -> order, matching
  // search_ccp4_bond()'s parameters - confirm in load_ccp4_bonds().
  std::map<std::string, std::map<std::string, std::map<std::string, Ccp4BondEntry>>> ccp4_bonds_;

  void load_ccp4_bonds(const std::string& path);
  std::vector<std::string> compute_ccp4_types(const ChemComp& cc,
                                              const std::vector<CodAtomInfo>& atom_info,
                                              const std::vector<std::vector<int>>& neighbors) const;
  bool search_ccp4_bond(const std::string& type1,
                        const std::string& type2,
                        const std::string& order,
                        CodStats& out) const;

  // Detailed indexed bond tables from allOrgBondTables/*.table
  // Flattened: 8-component compound key (ha1|ha2|hybr|ring|a1nb2|a2nb2|a1nb|a2nb)
  //   -> 2 inner map levels (a1_type_m -> a2_type_m -> vector<CodStats>)
  using BondIdx1D = std::unordered_map<std::string,
    std::map<std::string, std::map<std::string, std::vector<CodStats>>>>;
  BondIdx1D bond_idx_1d_;

  // Exact match with full COD class
  // Flattened: 10-component key (ha1|..|a1_type_m|a2_type_m)
  //   -> 2 inner levels (a1_type_f -> a2_type_f -> CodStats)
  using BondIdxFull = std::unordered_map<std::string,
    std::map<std::string, std::map<std::string, CodStats>>>;
  BondIdxFull bond_idx_full_;

  // Levels 3-8: 4-component key (ha1|ha2|hybr|ring)
  //   -> 4 inner levels (a1nb2 -> a2nb2 -> a1nb -> a2nb -> vector<CodStats>)
  using BondIdx2D = std::unordered_map<std::string,
    std::map<std::string, std::map<std::string,
    std::map<std::string, std::map<std::string,
    std::vector<CodStats>>>>>>;
  BondIdx2D bond_idx_2d_;

  // Level Nb2D: 6-component key (ha1|ha2|hybr|ring|a1nb2|a2nb2)
  using BondNb2D = std::unordered_map<std::string, std::vector<CodStats>>;
  BondNb2D bond_nb2d_;  // populated but not read (dead data)

  // Level Nb2DType: 8-component key (ha1|ha2|hybr|ring|a1nb2|a2nb2|root1|root2)
  using BondNb2DType = std::unordered_map<std::string, std::vector<CodStats>>;
  BondNb2DType bond_nb2d_type_;  // populated but not read (dead data)

  // Levels 9-11: Hash+Sp fallback, fully flat compound keys
  using BondHaSp2D = std::unordered_map<std::string, std::vector<CodStats>>;
  BondHaSp2D bond_hasp_2d_;  // 4-component key (ha1|ha2|hybr|ring)
  using BondHaSp1D = std::unordered_map<std::string, std::vector<CodStats>>;
  BondHaSp1D bond_hasp_1d_;  // 3-component key (ha1|ha2|hybr)
  using BondHaSp0D = std::unordered_map<std::string, std::vector<CodStats>>;
  BondHaSp0D bond_hasp_0d_;  // 2-component key (ha1|ha2)

  // Bond file index: flat 2-component key (ha1|ha2) -> table file number
  std::unordered_map<std::string, int> bond_file_index_;

  // Side sets for prefix-existence checks in bond lookup
  std::unordered_set<std::string> bond_2d_hybr_keys_;     // 3-component (ha1|ha2|hybr)
  std::unordered_set<std::string> bond_full_4prefix_keys_; // 4-component (ha1|ha2|hybr|ring)

  // Atom type code mapping: coded -> full type string
  std::unordered_map<std::string, std::string> atom_type_codes_;

  // Detailed indexed angle tables from allOrgAngleTables/*.table
  // Fully flat compound keys for levels 1D-4D, 6D
  using AngleIdx1D = std::unordered_map<std::string, std::vector<CodStats>>;
  AngleIdx1D angle_idx_1d_;   // 16-component key
  using AngleIdx2D = std::unordered_map<std::string, std::vector<CodStats>>;
  AngleIdx2D angle_idx_2d_;   // 13-component key
  using AngleIdx3D = std::unordered_map<std::string, std::vector<CodStats>>;
  AngleIdx3D angle_idx_3d_;   // 10-component key
  using AngleIdx4D = std::unordered_map<std::string, std::vector<CodStats>>;
  AngleIdx4D angle_idx_4d_;   // 7-component key

  // Level 5D: kept nested for wildcard iteration in fill_angle
  using AngleIdx5D = std::unordered_map<int, std::unordered_map<int,
    std::unordered_map<int, std::map<std::string, std::vector<CodStats>>>>>;
  AngleIdx5D angle_idx_5d_;

  // Level 6D: flat 3-component key (ha1|ha2|ha3)
  using AngleIdx6D = std::unordered_map<std::string, std::vector<CodStats>>;
  AngleIdx6D angle_idx_6d_;

  // Angle file index: kept nested for individual-level access in fill_angle
  std::unordered_map<int, std::unordered_map<int, std::unordered_map<int, int>>> angle_file_index_;

  // Element + hybridization based fallback bonds
  using ENBonds = std::map<std::string, std::map<std::string,
    std::map<std::string, std::map<std::string,
    std::vector<CodStats>>>>>;
  ENBonds en_bonds_;

  // Metal bond tables
  std::vector<MetalBondEntry> metal_bonds_;
  // Covalent radii, one slot per El value up to and including El::END
  // (presumably indexed by El - confirm in load_covalent_radii()).
  // Value-initialised with {} so every slot reads 0.0 until the loader writes
  // it: AcedrgTables() is defaulted, so a plain `AcedrgTables t;` would
  // otherwise leave these doubles indeterminate.
  std::array<double, static_cast<int>(El::END) + 1> covalent_radii_{};
  std::vector<MetalAngleEntry> metal_angles_;
  std::vector<MetalCoordOverride> metal_coord_geo_overrides_;
  std::map<std::string, TorsionEntry> pep_tors_;
  std::map<std::string, std::vector<TorsionEntry>> nucl_tors_;

  // Protonated hydrogen distances: maps type (e.g., "H_sp3_C") -> ProtHydrDist
  std::map<std::string, ProtHydrDist> prot_hydr_dists_;

  // Internal helper functions
  // (Note: these members are still in the public section.)

  // Loading functions
  void load_hash_codes(const std::string& path);
  void load_bond_hrs(const std::string& path);
  void load_angle_hrs(const std::string& path);
  void load_metal_tables(const std::string& dir);
  void load_covalent_radii(const std::string& path);
  void load_en_bonds(const std::string& path);
  void load_atom_type_codes(const std::string& path);
  void load_bond_index(const std::string& path);
  void load_bond_tables(const std::string& dir);
  void load_pep_tors(const std::string& path);
  void load_nucl_tors(const std::string& path);
  void load_prot_hydr_dists(const std::string& path);
  void load_angle_index(const std::string& path);
  void load_angle_tables(const std::string& dir);

 private:
  void compute_hash(CodAtomInfo& atom) const;

  // Bond search helpers
  CodStats search_bond_multilevel(const CodAtomInfo& a1,
                                  const CodAtomInfo& a2) const;
  CodStats search_bond_hrs(const CodAtomInfo& a1, const CodAtomInfo& a2,
                           bool in_ring) const;
  CodStats search_bond_en(const CodAtomInfo& a1, const CodAtomInfo& a2) const;
  ProtHydrDist search_prot_hydr_dist(const CodAtomInfo& h_atom,
                                     const CodAtomInfo& heavy_atom) const;
  CodStats search_metal_bond(const CodAtomInfo& metal,
                             const CodAtomInfo& ligand,
                             const std::vector<CodAtomInfo>& atoms) const;
  // Angle search helpers
  CodStats search_angle_multilevel(const CodAtomInfo& a1,
                                   const CodAtomInfo& center,
                                   const CodAtomInfo& a3,
                                   int* out_level = nullptr) const;
  CodStats search_angle_hrs(const CodAtomInfo& a1, const CodAtomInfo& center,
                            const CodAtomInfo& a3, int ring_size) const;
  std::vector<double> get_metal_angles(Element metal, int coord_number) const;

  // Utility: clamp sigma into [lower_*_sigma, upper_*_sigma].
  // A NaN sigma yields the upper bound: std::min(upper, NaN) returns `upper`
  // because the comparison NaN < upper is false. This assumes lower <= upper.
  double clamp_bond_sigma(double sigma) const {
    return std::max(lower_bond_sigma, std::min(upper_bond_sigma, sigma));
  }
  double clamp_angle_sigma(double sigma) const {
    return std::max(lower_angle_sigma, std::min(upper_angle_sigma, sigma));
  }
};

} // namespace gemmi

#endif
