// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2025 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
// details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef GRAPH_BLOCKMODEL_UNCERTAIN_MCMC_HH
#define GRAPH_BLOCKMODEL_UNCERTAIN_MCMC_HH

#include "config.h"

#include <vector>

#include "graph_tool.hh"
#include "../support/graph_state.hh"
#include "../loops/mcmc_loop.hh"
#include "sample_edge_sbm.hh"

#include "../support/sample_util.hh"
#include "../support/bisection_sampler.hh"

namespace graph_tool
{
using namespace boost;
using namespace std;

// Convenience alias for a list of size_t values.
using vlist_t = std::vector<size_t>;

// Parameter list of the uncertain-graph MCMC state, consumed by
// GEN_STATE_BASE / GET_PARAMS_USING / GET_PARAMS_TYPEDEF below. Each entry
// looks like (name, reference qualifier, type, flag); the exact meaning of the
// fields is defined by those macros (not visible in this file) -- NOTE(review):
// confirm against graph_state.hh.
//
// How the parameters are used in this file (each is accessed as `_<name>`):
//   state        - the wrapped State object whose edges are modified.
//   beta         - value returned by get_beta(); an infinite value makes
//                  move_proposal() return the bisect_fb() result directly.
//   entropy_args - passed to State::add_edge_dS() / State::remove_edge_dS().
//   pedges,
//   puniform     - passed to the SBMEdgeSampler constructor.
//   bisect       - selects the BisectionSampler-based proposal instead of
//                  sample_geometric().
//   verbose      - not read in this file; presumably used by the MCMC loop.
//   niter        - value returned by get_niter().
#define MCMC_UNCERTAIN_STATE_params(State)                                     \
    ((__class__,&, decltype(hana::tuple_t<python::object>), 1))                \
    ((state, &, State&, 0))                                                    \
    ((beta,, double, 0))                                                       \
    ((entropy_args,, uentropy_args_t, 0))                                      \
    ((pedges,, double, 0))                                                     \
    ((puniform,, double, 0))                                                   \
    ((bisect,, bool, 0))                                                       \
    ((verbose,, int, 0))                                                       \
    ((niter,, size_t, 0))


// Wrapper that generates the MCMC state type for a given uncertain-graph
// State. The nested MCMCUncertainState proposes changes to the multiplicity of
// a single edge (u, v) of State and exposes the interface used by the MCMC
// loop (move_proposal / virtual_move_dS / perform_move and the small getters).
template <class State>
struct MCMC
{
    GEN_STATE_BASE(MCMCUncertainStateBase, MCMC_UNCERTAIN_STATE_params(State))

    // MCMC state over edge multiplicities. The parameters listed in
    // MCMC_UNCERTAIN_STATE_params are available as `_<name>` members.
    template <class... Ts>
    class MCMCUncertainState
        : public MCMCUncertainStateBase<Ts...>,
          public MetropolisStateBase
    {
    public:
        GET_PARAMS_USING(MCMCUncertainStateBase<Ts...>,
                         MCMC_UNCERTAIN_STATE_params(State))
        GET_PARAMS_TYPEDEF(Ts, MCMC_UNCERTAIN_STATE_params(State))

        // Constructs the state from exactly sizeof...(Ts) arguments (enforced
        // by the enable_if), builds the edge sampler on top of the block state
        // and sizes the vertex list to the number of vertices of _state._u.
        template <class... ATs,
                  typename std::enable_if_t<sizeof...(ATs) ==
                                            sizeof...(Ts)>* = nullptr>
        MCMCUncertainState(ATs&&... as)
            : MCMCUncertainStateBase<Ts...>(as...),
              _edge_sampler(_state._block_state, _pedges, _puniform),
              _vlist(num_vertices(_state._u))
        {
        }

        // Samples the (u, v) pair whose multiplicity is to be changed.
        SBMEdgeSampler<typename State::block_state_t> _edge_sampler;

        // Edge selected by the most recent call to move_proposal().
        std::tuple<size_t, size_t> _e;
        // Value-initialised list with one entry per vertex of _state._u; the
        // entries themselves are never read in this class (the per-"node"
        // argument of the move functions is ignored).
        std::vector<size_t> _vlist;
        // Move value that leaves the state unchanged (dm == 0).
        int64_t _null_move = 0;

        // Returns the edge selected by the last move_proposal().
        std::tuple<size_t, size_t> get_edge()
        {
            return _e;
        }

        // Current multiplicity of edge (u, v): 0 if the edge does not exist in
        // the state, otherwise its entry in _state._eweight.
        size_t node_state(size_t u, size_t v)
        {
            auto&& e = _state.get_u_edge(u, v);
            if (e == _state._null_edge)
                return 0;
            return _state._eweight[e];
        }

        // Same as above for the currently selected edge; the argument is
        // ignored.
        size_t node_state(size_t)
        {
            size_t u, v;
            std::tie(u, v) = get_edge();
            return node_state(u, v);
        }

        // Sampler built by move_proposal() when _bisect is set, and reused by
        // virtual_move_dS() to evaluate proposal log-probabilities.
        BisectionSampler<int64_t> _bisection_sampler;

        // Selects a new edge (stored in _e) and proposes a change `dm` to its
        // multiplicity. Returns the proposed difference (new - current).
        template <class RNG>
        int64_t move_proposal(size_t, RNG& rng)
        {
            _e = _edge_sampler.sample(rng);

            // Unpack into ordinary locals (not structured bindings) so they
            // can be captured by value below.
            size_t u, v;
            std::tie(u, v) = _e;
            int64_t m = node_state(u, v);

            if (_bisect)
            {
                // The closure is stored inside the member _bisection_sampler
                // and therefore outlives this call: capture the edge and its
                // multiplicity by value so that it never refers to dead stack
                // variables.
                _bisection_sampler =
                    BisectionSampler<int64_t>
                    ([this, u, v, m](auto nm)
                     {
                         int64_t dm = nm - m;
                         if (dm < 0)
                             return _state.remove_edge_dS(u, v, -dm, _entropy_args);
                         else
                             return _state.add_edge_dS(u, v, dm, _entropy_args);
                     }, bisect_args_t());
                auto nm = _bisection_sampler.bisect_fb(0, _state._max_m, rng);
                // With infinite beta the bisection result is used directly;
                // otherwise a value is drawn from the sampler.
                if (std::isinf(_beta))
                    return nm - m;
                return _bisection_sampler.sample(1, 1, rng) - m;
            }

            return sample_geometric(m + 2, rng) - m;
        }

        // Evaluates a proposed change `dm` for the currently selected edge
        // without applying it. Returns (dS, a): dS is the value reported by
        // add_edge_dS / remove_edge_dS, and `a` accumulates the log-probability
        // terms of the edge selection and of the multiplicity proposal
        // (presumably reverse minus forward -- NOTE(review): confirm against
        // the MCMC loop). Returns (0, 0) for dm == 0.
        //
        // When _bisect is set this relies on the sampler built by the
        // preceding move_proposal() call.
        std::tuple<double, double>
        virtual_move_dS(size_t, int64_t dm)
        {
            if (dm == 0)
                return {0., 0.};

            size_t u, v;
            std::tie(u, v) = get_edge();

            double dS = 0;
            if (dm < 0)
                dS = _state.remove_edge_dS(u, v, -dm, _entropy_args);
            else
                dS = _state.add_edge_dS(u, v, dm, _entropy_args);

            size_t m = node_state(u, v);
            double a = (_edge_sampler.log_prob(u, v, m, dm) -
                        _edge_sampler.log_prob(u, v, m, 0));

            // Multiplicity after the move.
            size_t nm = m + dm;

            if (_bisect)
            {
                a -= _bisection_sampler.lprob(nm, 1, 1);
                a += _bisection_sampler.lprob(m, 1, 1);
            }
            else
            {
                a -= geometric_lprob(nm, m + 2);
                a += geometric_lprob(m, nm + 2);
            }

            return std::make_tuple(dS, a);
        }

        // Applies the change `dm` to the currently selected edge and keeps the
        // edge sampler in sync. Does nothing for dm == 0.
        void perform_move(size_t, int64_t dm)
        {
            if (dm == 0)
                return;

            size_t u, v;
            std::tie(u, v) = get_edge();
            size_t m = node_state(u, v);

            // Note the ordering: on removal the sampler is updated before the
            // state is modified, on addition after it.
            if (dm < 0)
            {
                _edge_sampler.update_edge(u, v, m, dm);
                _state.remove_edge(u, v, -dm);
            }
            else
            {
                _state.add_edge(u, v, dm);
                _edge_sampler.update_edge(u, v, m, dm);
            }
        }

        // Always false for this state.
        bool is_deterministic()
        {
            return false;
        }

        // Always false for this state.
        bool is_sequential()
        {
            return false;
        }

        // Returns a reference to the internal vertex list.
        auto& get_vlist()
        {
            return _vlist;
        }

        // Returns the `beta` parameter.
        double get_beta()
        {
            return _beta;
        }

        // Returns the `niter` parameter.
        size_t get_niter()
        {
            return _niter;
        }
    };
};


} // graph_tool namespace

#endif //GRAPH_BLOCKMODEL_UNCERTAIN_MCMC_HH
