// Copyright (c) 2002-2010 Wieger Wesselink
//
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

/// \file sat/reach.h
/// \brief Contains reachability algorithm

#ifndef SAT_REACH_H
#define SAT_REACH_H

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include "sat/variable_set.h"

// Very simplistic function to determine if a file exists: returns true if
// 'filename' can be opened for reading with std::ifstream.
inline
bool exists_file(const std::string& filename)
{
  std::ifstream from(filename.c_str());
  // std::basic_ios::operator bool is explicit since C++11, so the stream can
  // not be returned as a bool directly; query whether the open succeeded.
  return from.is_open();
}

namespace sat {

///////////////////////////////////////////////////////////////////////////////
// ReachAlgorithm
/// \brief implements a reachability algorithm.
///
/// The transitions are defined by the transition relation T:
///
///    T(x0,x1) <=> there is a transition from x0 to x1
///
/// The initial set of states is given by I, and the final set of
/// states by F.
///
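/// The algorithm generates the sequence of sets of states
/// x(i), i = 0, 1, 2, ... such that x(0) = I, and x(k+1) is the union
/// of x(k) and the set of all states x that satisfy T(x(k),x). Hence
/// x(k) is the set of states reachable from I in at most k transitions.
/// The sequence stops as soon as x(k+1) = x(k) (or, if requested, as
/// soon as x(k) intersects F).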
///
/// To distinguish between the x0- and x1-variables, the parameters
/// variables0 and variables1 are used. These are mappings from bdd
/// indexes to strings and vice versa. They must share exactly the
/// same set of strings, and their indexes may not overlap. We denote
/// the x0 variables as '0' and the x1 variables as '1' variables.
template <class BDD>
class ReachAlgorithm
{
 protected:
    /// \brief Returns the name of the file that holds the state after i iterations.
    /// \details The name is filename + "_" + i + ".bdd", where i is padded with
    /// leading zeros to at least three digits (e.g. "foo_007.bdd").
    std::string state_file(int i) const
    {
      std::ostringstream os;
      os << filename << "_" << std::setw(3) << std::setfill('0') << i << ".bdd";
      return os.str();
    }

    /// \brief Writes the current state b to state_file(iteration).
    /// \details Does nothing unless a file name has been set with save_states().
    void save_state() const
    {
      if (!filename.empty())
      {
        buddy::save(state_file(iteration), b);
      }
    }

 public:
    /// \brief Constructor.
    /// \param I_ The initial states (in '0' variables).
    /// \param T_ The transition relation T(x0,x1) over '0' and '1' variables.
    /// \param F_ The final states (in '0' variables).
    /// \param variables0_ The '0' variables.
    /// \param variables1_ The '1' variables.
    /// \post iteration == 0 and the current state equals I.
    ReachAlgorithm(const BDD& I_,
                   const BDD& T_,
                   const BDD& F_,
                   const buddy::variable_set& variables0_,
                   const buddy::variable_set& variables1_
                  )
      : I(I_),
        T(T_),
        F(F_),
        variables0(variables0_),
        variables1(variables1_),
        iteration(0),
        n(0),
        b(I_)
    {
      // NOTE: init() is virtual, but a call made from a constructor is never
      // dispatched to a derived class: this always runs ReachAlgorithm::init().
      // Derived classes that override init() must invoke their version themselves.
      init();
    }

    /// \brief Virtual destructor, so that derived algorithms can be deleted
    /// through a base pointer.
    virtual ~ReachAlgorithm()
    {}

    /// \brief Converts x from '1' to '0' variables (using permutation1).
    BDD convert_to_0_variables(const BDD& x) const
    {
      return buddy::replace(x, permutation0);
    }

    /// \brief Converts x from '0' to '1' variables.
    BDD convert_to_1_variables(const BDD& x) const
    {
      return buddy::replace(x, permutation1);
    }

    /// \brief Returns buddy::permute applied to x with permutation0 over 2*n
    /// variables; invert() uses this to reverse the direction of T.
    BDD get_inverse_transition_relation(const BDD& x) const
    {
      return buddy::permute(x, permutation0, 2*n);
    }

    /// \brief Returns the conjunction of all '0' variables.
    const BDD& get_var0() const
    {
      return var0;
    }

    /// \brief Returns the conjunction of all '1' variables.
    const BDD& get_var1() const
    {
      return var1;
    }

    /// \brief Inverts the problem, so that the reachability algorithm starts
    /// in the final states and goes backwards.
    /// \details Swaps I and F, replaces T by its inverse and restarts at
    /// iteration 0 with the current state equal to the (new) I.
    virtual void invert()
    {
      std::swap(I, F);
      T = get_inverse_transition_relation(T);
      iteration = 0;
      b = I;
    }

    /// \brief Enables storing of intermediate results.
    /// \details After this call (with a non-empty name) the state after i
    /// iterations is written to state_file(i).
    void save_states(const std::string& filename_)
    {
      filename = filename_;
    }

    /// \brief Initializes the variable permutations, var0/var1 and variables01
    /// from variables0 and variables1.
    virtual void init()
    {
      // initialize permutation0 ('0' -> '1') and permutation1 ('1' -> '0');
      // the two variable sets are matched by name
      n = variables0.size();
      for (int j = 0; j < n; j++)
      {
        int i0 = variables0.index(j);
        int i1 = variables1.index(variables0.name(j));
        permutation0[i0] = i1;
        permutation1[i1] = i0;
      }

      // initialize var0 and var1
      var0 = buddy::one();
      var1 = buddy::one();
      for (int j = 0; j < n; j++)
      {
        var0 &= bdd_ithvar(variables0.index(j));
        var1 &= bdd_ithvar(variables1.index(j));
      }

      // initialize variables01; the names get a "0" or "1" suffix
      for (int j = 0; j < n; j++)
      {
        variables01.set_index(variables0.name(j) + "0", variables0.index(j));
        variables01.set_index(variables1.name(j) + "1", variables1.index(j));
      }
    }

    /// \brief Returns x together with all its direct successors.
    /// \details The successors are obtained by existentially quantifying the
    /// '0' variables out of x && T (buddy::exists), renaming the result to
    /// '0' variables, and joining it with x.
    virtual BDD add_transition(const BDD& x)
    {
      // use existential quantification to find x' such that Ex: b(x) && T(x,x')
      BDD next = buddy::exists(x, T, var0);

      next = convert_to_0_variables(next);

      // join next with x
      next |= x;

      return next;
    }

    /// \brief Performs one iteration.
    /// \return false when no new states were found, or when stop_at_solution
    /// is true and the current state already intersects F (in that case no
    /// iteration is performed).
    /// pre:  iteration = i   and b contains all reachable states after iteration iterations
    /// post: iteration = i+1 and b contains all reachable states after iteration iterations
    ///       (iteration is only incremented when new states were found)
    virtual bool iterate(bool stop_at_solution = false)
    {
      if (stop_at_solution && ((b & F) != buddy::zero()))
        return false;

      BDD bnext = add_transition(b);
      bool last_iteration = (bnext == b);
      b = bnext;

      if (!last_iteration)
      {
        iteration++;
        save_state();
      }

// REORDERING
//int bsize = buddy::size(b);
//if (bsize > 150000)
//{
//  std::cout << "Before reordering: bdd size = " << buddy::size(b) << std::endl;
//  bdd_reorder(BDD_REORDER_SIFT   );
//  std::cout << "After reordering : bdd size = " << buddy::size(b) << std::endl;
//}
      return !last_iteration;
    }

    /// \brief Runs the algorithm until iterate() returns false.
    /// \return All reachable states that are contained in F (in '0' variables).
    virtual BDD run(bool stop_at_solution = false)
    {
      save_state();
      while (iterate(stop_at_solution))
      {
        // all work is done by iterate()
      }
      return b & F;        // reachable states in '0' variables
    }

    /// \brief Returns the number of iterations that have been performed.
    int get_iteration() const
    {
      return iteration;
    }

    /// \brief Returns the state after 'iteration' iterations.
    const BDD& get_state() const
    {
      return b;
    }

    /// \brief Returns the transition relation T.
    const BDD& get_transition_relation() const
    {
      return T;
    }

    /// \brief Returns true if T(x,y) holds for some x and y in the given sets
    /// (both in '0' variables).
    virtual bool is_transition(const BDD& x, const BDD& y) const
    {
      BDD y1 = convert_to_1_variables(y);
      return (T & x & y1) != buddy::zero();
    }

    /// \brief Checks that x is a valid path from 'initial' to 'final'.
    /// \details x must be non-empty, x.front() must intersect 'initial',
    /// x.back() must intersect 'final', and is_transition(x[i], x[i+1]) must
    /// hold for every consecutive pair. On the first violation a diagnostic
    /// is written to std::cerr and false is returned.
    bool check_solution(const BDD& initial, const BDD& final, const std::vector<BDD>& x)
    {
      if (x.empty())
      {
        std::cerr << "ReachAlgorithm::check_solution : empty solution!" << std::endl;
        return false;
      }
      if ((x.front() & initial) == buddy::zero())
      {
        std::cerr << "ReachAlgorithm::check_solution : solution start wrong!" << std::endl;
        return false;
      }
      if ((x.back() & final) == buddy::zero())
      {
        std::cerr << "ReachAlgorithm::check_solution : solution end wrong!" << std::endl;
        return false;
      }
      for (typename std::vector<BDD>::size_type i = 0; i + 1 < x.size(); ++i)
      {
        if (!is_transition(x[i], x[i+1]))
        {
          std::cerr << "ReachAlgorithm::check_solution : transition " << i << " -> " << (i + 1) << " wrong!" << std::endl;
          return false;
        }
      }
      return true;
    }

    /// \brief Returns all solutions y of T(x,y) in '0' variables.
    virtual BDD get_target(const BDD& x) const
    {
      BDD y = buddy::exists(x, T, var0);
      y = convert_to_0_variables(y);
      return y;
    }

    /// \brief Returns all solutions x of T(x,y) in '0' variables.
    virtual BDD get_source(const BDD& y) const
    {
      BDD y1 = convert_to_1_variables(y);
      BDD x = buddy::exists(y1, T, var1);
      return x;
    }

    /// \brief Resumes at the last state that is available on disk.
    /// \details Finds the largest k such that state_file(0), ..., state_file(k)
    /// all exist, loads state k into the current state and sets iteration to k.
    /// Writes a message to std::cerr if state_file(0) does not exist.
    virtual void resume()
    {
      // 'last' is deliberately not called 'n': that name is the member holding
      // the number of variables.
      int last = -1;
      while (exists_file(state_file(last + 1)))
      {
        last++;
      }

      if (last >= 0)
      {
        b = buddy::load(state_file(last));
        iteration = last;
      }
      else
      {
        std::cerr << "found nothing to resume!" << std::endl;
      }
    }

    /// \brief Returns the number of possible transitions (satcount of T over
    /// the 2*n '0' and '1' variables).
    double transition_count() const
    {
      return buddy::satcount(T, 2*n);
    }

 protected:
    BDD I;                                 // initial states
    BDD T;                                 // transition relation
    BDD F;                                 // final states
    buddy::variable_set variables0;        // '0' variables
    buddy::variable_set variables1;        // '1' variables
    buddy::variable_set variables01;       // '0' and '1' variables
    int iteration;                         // number of iterations performed
    int n;                                 // number of variables
    std::map<int, int> permutation0;       // permutation from '0' to '1' variables
    std::map<int, int> permutation1;       // permutation from '1' to '0' variables
    BDD var0;                              // conjunction of the '0' transition variables
    BDD var1;                              // conjunction of the '1' transition variables
    BDD b;                                 // the state after 'iteration' iterations
    std::string filename;                  // if (filename != "") then the state after i iterations
                                           // is stored in state_file(i)
};

//---------------------------------------------------------//
//                    ReachAlgorithmPrint
//---------------------------------------------------------//
// This class prints information to a stream.
//
template <class BDD>
class ReachAlgorithmPrint: public ReachAlgorithm<BDD>
{
 public:
    /// \brief Constructor; the same as ReachAlgorithm, plus the stream out_ to
    /// which progress is reported. The stream is held by reference and must
    /// outlive this object.
    ReachAlgorithmPrint(const BDD& I_,
                        const BDD& T_,
                        const BDD& F_,
                        const buddy::variable_set& variables0_,
                        const buddy::variable_set& variables1_,
                        std::ostream& out_
                       )
      : ReachAlgorithm<BDD>(I_, T_, F_, variables0_, variables1_),
        out(out_)
    {
    }

    /// \brief Performs one iteration (see ReachAlgorithm::iterate) and prints
    /// the new state when new states were found.
    /// \return false when no new states were found, or when stop_at_solution
    /// is true and the current state already intersects F.
    bool iterate(bool stop_at_solution = false)
    {
      // The stop_at_solution check is performed by the base class; repeating
      // it here would only cost an extra BDD conjunction per iteration.
      bool result = ReachAlgorithm<BDD>::iterate(stop_at_solution);
      if (result)
      {
        print_state();
      }
      return result;
    }

    /// \brief Prints the initial state, then runs the algorithm.
    /// \return All reachable states that are contained in F.
    BDD run(bool stop_at_solution = false)
    {
      print_state();
      return ReachAlgorithm<BDD>::run(stop_at_solution);
    }

 protected:
    /// \brief Prints the iteration number, the bdd size and the number of
    /// satisfying assignments of the current state to 'out'.
    virtual void print_state()
    {
      out << "ITERATION " << this->iteration << std::endl;
      out << "bdd size                   : " << buddy::size(this->b) << std::endl;
      out << "number of reachable states : " << buddy::satcount(this->b, this->n) << "\n" << std::endl;
    }

    std::ostream& out;  // destination of the progress output (not owned)
};

//---------------------------------------------------------//
//                    ReachAlgorithmBacktrack
//---------------------------------------------------------//
// This class has an additional method to backtrack a solution to
// the reachability problem. It has an abstract method 'get_state(int i)'
// that should return the state after i iterations.
//
template <class BDD>
class ReachAlgorithmBacktrack: public ReachAlgorithmPrint<BDD>
{
 public:
    /// \brief Constructor; the same as ReachAlgorithmPrint.
    /// \note The base constructor's call to init() does not reach the
    /// override below, so Tinv is only computed by an explicit init() call.
    ReachAlgorithmBacktrack(const BDD& I_,
                            const BDD& T_,
                            const BDD& F_,
                            const buddy::variable_set& variables0_,
                            const buddy::variable_set& variables1_,
                            std::ostream& out_
                           )
      : ReachAlgorithmPrint<BDD>(I_, T_, F_, variables0_, variables1_, out_),
        backtrack_iteration(0)
    {
    }

    /// \brief Returns the state after i iterations.
    virtual BDD get_state(int i) const = 0;

    /// \brief Initializes the base class, computes Tinv and resets the
    /// backtrack position.
    void init()
    {
      ReachAlgorithmPrint<BDD>::init();
      // 'this->' is required: the member lives in a dependent base class and
      // would otherwise not be found by two-phase name lookup.
      Tinv = this->get_inverse_transition_relation(this->T);
      backtrack_iteration = 0;
    }

    /// \brief Returns the smallest index i, 0 <= i <= iteration, such that
    /// (get_state(i) & x) is non-empty, or -1 if no such index exists.
    int find_intersection(const BDD& x)
    {
      for (int i = 0; i <= this->iteration; i++)
      {
        if ((get_state(i) & x) != buddy::zero())
          return i;
      }
      return -1;
    }

    /// \brief Performs one backtracking step.
    /// \return false when backtrack_iteration has become 0.
    /// pre:  backtrack_iteration = i   and c[backtrack_iteration] in get_state(backtrack_iteration)
    /// post: backtrack_iteration = i-1 and c[backtrack_iteration] in get_state(backtrack_iteration)
    virtual bool iterate_back()
    {
      if (backtrack_iteration == 0)
        return false;

      // 'this->' is required for the dependent base member (two-phase lookup)
      BDD prev = this->get_source(c[backtrack_iteration]);

      backtrack_iteration--;
      c[backtrack_iteration] = buddy::find_sat_solution(get_state(backtrack_iteration) & prev, this->variables0);

      return true;
    }

    /// \brief Computes a shortest solution towards final_state.
    /// \return The sequence c described below, or an empty vector if
    /// final_state does not intersect any stored state.
    virtual std::vector<BDD> run_back(const BDD& final_state)
    {
      int i = find_intersection(final_state);

      if (i < 0)
        return std::vector<BDD>();

      // reserve the proper amount of space for the solution
      c = std::vector<BDD>(i + 1);

      backtrack_iteration = i;
      c[backtrack_iteration] = buddy::find_sat_solution(final_state & get_state(backtrack_iteration), this->variables0);
      while (iterate_back())
      {
        // all work is done by iterate_back()
      }

      //if (!this->check_solution(this->I, final_state, c))
      //  std::cout << "run_back solution incorrect!!!" << std::endl;

      return c;
    }

 protected:
    BDD Tinv;                 // the inverse of the transition relation T (set by init())
    int backtrack_iteration;  // index of the last element of c computed by run_back
    std::vector<BDD> c;       // after calling run_back, c will contain a sequence that satisfies
                              // -  c[0] in I
                              // -  T(c[i], c[i+1])
                              // -  c.back() in final_state
};

//---------------------------------------------------------//
//                    ReachAlgorithmBacktrackMemory
//---------------------------------------------------------//
// Backtrack from memory.
//
template <class BDD>
class ReachAlgorithmBacktrackMemory: public ReachAlgorithmBacktrack<BDD>
{
 public:
    /// \brief Constructor; the same as ReachAlgorithmBacktrack.
    ReachAlgorithmBacktrackMemory(const BDD& I_,
                                  const BDD& T_,
                                  const BDD& F_,
                                  const buddy::variable_set& variables0_,
                                  const buddy::variable_set& variables1_,
                                  std::ostream& out_
                                 )
      : ReachAlgorithmBacktrack<BDD>(I_, T_, F_, variables0_, variables1_, out_)
    {
    }

    /// \brief Returns the state after i iterations, taken from the in-memory history B.
    /// \throws std::out_of_range if no state has been recorded for iteration i.
    BDD get_state(int i) const
    {
      // at() instead of operator[]: an index that does not correspond to a
      // recorded iteration must not silently read past the end of B.
      return B.at(static_cast<typename std::vector<BDD>::size_type>(i));
    }

    /// \brief Forwards to ReachAlgorithmBacktrack::init().
    void init()
    {
      ReachAlgorithmBacktrack<BDD>::init();
    }

    /// \brief Does nothing: states are kept in memory, not on disk.
    void resume()
    {
      // useless operation for ReachAlgorithmBacktrackMemory
    }

    /// \brief Runs the algorithm while recording every intermediate state in B.
    /// \details The history is restarted only when no iterations have been
    /// performed yet. Calling run() again keeps the existing history, so that
    /// B[i] still matches the state after i iterations. (Previously B was reset
    /// to [I] while 'iteration' stayed > 0, so backtracking read past the end
    /// of B.)
    BDD run(bool stop_at_solution = false)
    {
      restart_history_if_at_start();
      return ReachAlgorithmBacktrack<BDD>::run(stop_at_solution);
    }

    /// \brief Performs one iteration and appends the new state to B.
    /// \return false when no new states were found.
    bool iterate(bool stop_at_solution = false)
    {
      // keeps B aligned even when iterate() is called without run()
      restart_history_if_at_start();
      bool result = ReachAlgorithmBacktrack<BDD>::iterate(stop_at_solution);
      if (result)
        B.push_back(this->b);
      return result;
    }

 protected:
    /// \brief Starts a new history B = [I] when iteration == 0; otherwise leaves B untouched.
    void restart_history_if_at_start()
    {
      if (this->iteration == 0)
      {
        B.clear();
        B.push_back(this->I);
      }
    }

    std::vector<BDD> B; // B[i] will contain the set of reachable states in '0' variables after <= i iterations
};

//---------------------------------------------------------//
//                    ReachAlgorithmBacktrackDisk
//---------------------------------------------------------//
// Backtrack from disk.
//
template <class BDD>
class ReachAlgorithmBacktrackDisk: public ReachAlgorithmBacktrack<BDD>
{
 public:
    /// \brief Constructor; the same as ReachAlgorithmBacktrack, plus the base
    /// file name filename_ under which intermediate states are stored.
    ReachAlgorithmBacktrackDisk(const BDD& I_,
                                const BDD& T_,
                                const BDD& F_,
                                const buddy::variable_set& variables0_,
                                const buddy::variable_set& variables1_,
                                std::ostream& out_,
                                const std::string& filename_
                               )
      : ReachAlgorithmBacktrack<BDD>(I_, T_, F_, variables0_, variables1_, out_)
    {
      // With a non-empty name, save_state() writes every state to disk;
      // get_state() reads them back from there.
      this->save_states(filename_);
    }

    /// \brief Returns the state after i iterations, loaded from state_file(i).
    BDD get_state(int i) const
    {
      const std::string path = this->state_file(i);
      return buddy::load(path);
    }
};

} // namespace sat

#endif // SAT_REACH_H
