#include <iostream>
#include <fstream>
#include <cmath>
#include <cstdlib>
#include <string>
using namespace std;
// routine for iterative solver gmres
#include "gmres.h"
#include "bicgstab.h"
#include "cghs.h"
#include "vectorT.hh"
// problem description:
// solve -u'' = f on (-1,1) with u(-1)=u(1)=0
// using piecewise linear ansatz functions, once continuous (CG)
// and once discontinuous (DG)
// Abstract description of a test problem -u'' = f with exact solution lsg.
class Data {
 public:
  // Problems are handled through Data pointers/references, so deleting a
  // derived problem via a Data* must run the derived destructor.
  virtual ~Data() {}
  // exact solution u at x (if known; the default returns 0)
  virtual double lsg(double x) const {
    return 0.;
  }
  // right hand side f at x
  virtual double f(double x) const = 0;
  // short human readable name of the problem
  virtual string info() = 0;
};
class DataSin : public Data { 
  // Manufactured problem with exact solution
  //   u(x) = (x*x-1)*sin(omega*x),   omega = 1.5*M_PI,
  // which vanishes at x = -1 and x = 1 because of the factor (x*x-1).
  const double omega;
 public:
  DataSin() :
    omega(1.5*M_PI) {}
  // exact solution u(x)
  virtual double lsg(double x) const {  
    return (x*x-1.)*sin(omega*x);
  }
  // right hand side f = -u'' with
  //   u''(x) = (2 - x^2*omega^2 + omega^2)*sin(omega*x) + 4*x*omega*cos(omega*x)
  virtual double f(double x) const {
    const double upp = (2.-x*x*omega*omega+omega*omega)*sin(omega*x)
                       +4.*x*omega*cos(omega*x);
    return -upp;
  }
  // information
  virtual string info() {
    return "Sin Problem";
  }
};
class DataTanh : public Data { 
  // Manufactured problem with exact solution
  //   u(x) = tanh(epsilon*(x-x0))*sin(omega*x) + 1,   omega = 1.5*M_PI.
  // The tanh factor has a steep layer of width ~1/epsilon around x = x0.
  // With omega = 1.5*M_PI we have sin(omega) = -1 and sin(-omega) = 1, and
  // since epsilon is large tanh(epsilon*(+-1-x0)) is +-1 to machine
  // precision, so the "+1" makes u(-1) and u(1) vanish (up to rounding).
  const double epsilon;
  const double omega;
  const double x0;
 public:
  DataTanh() :
    epsilon(113.),
    omega(1.5*M_PI),
    x0(M_PI/sqrt(31.)) {}
  // exact solution u(x)
  virtual double lsg(double x) const {  
    return tanh(epsilon*(x-x0))*sin(omega*x)+1;
  }
  // right hand side f = -u''. With T = tanh(epsilon*(x-x0)), S = sin(omega*x):
  //   T'  = epsilon*(1-T^2),  T'' = -2*epsilon^2*T*(1-T^2),
  //   u'' = T''*S + 2*T'*S' + T*S''   (the constant 1 drops out)
  virtual double f(double x) const {
    const double T     = tanh(epsilon*(x-x0));
    const double sech2 = 1.0-T*T;   // 1 - tanh^2 = sech^2
    const double upp = -2.0*T*sech2*epsilon*epsilon*sin(omega*x)
                       +2.0*sech2*epsilon*cos(omega*x)*omega
                       -T*sin(omega*x)*omega*omega;
    return -upp;
  }
  // information
  virtual string info() {
    return "Tanh Problem";
  }
};
// A simple 1d Grid with uniform grid density
class Grid {
 protected:
  int n;                          // number of subintervals
  VectorT<double> h_;             // width of each subinterval
  VectorT<double> x_;             // left end point of each subinterval
  double haverage_;               // nominal width 2/n
 public:
  // constructor: pn = number of subintervals of [-1,1].
  // NOTE(review): pn >= 1 is required (x_[0] and h_[n-1] are written
  // unconditionally) but is not checked here.
  // The initializer list follows the member declaration order.
  Grid(int pn) :
    n(pn), h_(n), x_(n), haverage_(2./(double)(n)) {
      x_[0] = -1.;
      for (int i=1;i<n;i++) {
        h_[i-1] = haverage_;
        x_[i] = x_[i-1]+h_[i-1];
      }
      // the last cell takes whatever is left up to x = 1, so rounding
      // accumulated in x_ does not shift the right boundary
      h_[n-1] = 1.-x_[n-1];
  }
  // number of cells
  int size() const {
    return n;
  }
  // width of cell i
  double h(int i) const {
    return h_[i];
  }
  double x(int i,double l) const {  
    // map local coordinate l in [0,1] to the global coordinate on cell i
    return xl(i)+l*h_[i];
  }
  double xm(int i) const {  
    // midpoint of cell i
    return x_[i]+0.5*h_[i];
  }
  double xl(int i) const {  
    // left point of cell i
    return x_[i];
  }
  double xr(int i) const {  
    // right point of cell i
    return x_[i]+h_[i];
  }
};
// Discrete Functions
class DiscLinFunc{
 protected:
  const Grid& grid;
  int n;                      // number of dofs
  VectorT<double> U;           // vectors for right side and solution
 public:
  // constructor: parameters for n,alpha,beta
  DiscLinFunc(const Grid& pgrid) :  
    grid(pgrid),
    U(2*grid.size()) {
    }
  ~DiscLinFunc() {  // destructor
  }
  int baseSize() const {
    return 2;
  }
  double phi(int k,double l) const {
    // phihat_k(l)
    if (k==0) return 1.;
    else return l-0.5;
  }
  double operator()(int i,double l) const {  
    // evaluate solution on subinterval i with local coordinate l
    return U[2*i]+U[2*i+1]*(l-0.5);
  }
  double& dof(int i,int k) {
    // return local dof k on cell i
    return U[2*i+k];
  }
  double* raw() {
    // a bad hack (required for use in iterative solver)
    return U.raw();
  }
};
class ContLinFunc {
 protected:
  const Grid& grid;
  int n;                      // number of dofs
  VectorT<double> U;           // vectors for right side and solution
 public:
  // constructor: parameters for n,alpha,beta
  ContLinFunc(const Grid& pgrid) :  
    grid(pgrid),
    U(grid.size()+1) {  // solution vector 
                        // (including boundary points
    }
  ~ContLinFunc() {  // destructor
  }
  int baseSize() const {
    // number of base functions on reference element
    return 2;
  }
  double phi(int k,double l) const {
    // phihat_k(l)
    if (k==0) return 1.-l;
    else return l;
  }
  double operator()(int i,double l) const {  
    // evaluate solution on subinterval i with local coordinate l
    return U[i]*(1.-l)+U[i+1]*l;
  }
  double& dof(int i,int k) {
    // return local dof k on cell i
    return U[i+k];
  }
  double* raw() {
    // a bad hack (required for use in iterative solver)
    return U.raw();
  }
};
/********************************************************
   CLASS for DG Method
********************************************************/
// Matrix-vector product needed by the iterative solvers. It must be a friend
// of DG, which requires the forward declarations below; the function itself
// is defined after the class.
template <class Data> class DG;
template <class Data> void mult(const DG<Data>&, const double *, double *);
// discretization class
template <class Data>
class DG {
 protected:
  const Grid grid;
  const Data& dat;               // problem data (not owned, must outlive this object)
  double alpha,beta;             // stabilization parameters for dg method
  DiscLinFunc F;                 // right hand side (load vector)
  DiscLinFunc U;                 // solution, also the initial guess of the solver
 public:
  // constructor: pdat = problem data, pn = number of cells,
  // palpha/pbeta = parameters of the DG method (used in mult())
  DG(const Data& pdat,int pn,double palpha,double pbeta) :  
    grid(pn),
    dat(pdat),
    alpha(palpha), beta(pbeta), 
    F(grid),
    U(grid) {
    // assemble right hand side using 2 point Gauss quadrature on [0,1]
    const double l0 = (1.-sqrt(3.)/3.)*0.5; // gauss ... 
    const double l1 = (1.+sqrt(3.)/3.)*0.5; // ... points 
    const double weight = 0.5;
    for (int i=0;i<grid.size();i++) {
      for (int k=0;k<F.baseSize();k++) {
        F.dof(i,k)  = F.phi(k,l0) * dat.f(grid.x(i,l0)) * grid.h(i)*weight;
        F.dof(i,k) += F.phi(k,l1) * dat.f(grid.x(i,l1)) * grid.h(i)*weight;
        // start the solver from zero: VectorT's constructor is not visible
        // here, so do not rely on it zero-filling U
        U.dof(i,k) = 0.;
      }
    }
  }
  double error() const {  
    // L^2 error ||u_h - u|| using 2 point Gauss quadrature on each element
    const double l0 = (1.-sqrt(3.)/3.)*0.5;
    const double l1 = (1.+sqrt(3.)/3.)*0.5;
    double err = 0.;
    for (int i=0;i<grid.size();i++) {
      const double e0 = U(i,l0)-dat.lsg(grid.x(i,l0));
      const double e1 = U(i,l1)-dat.lsg(grid.x(i,l1));
      err += grid.h(i)*(e0*e0*0.5 + e1*e1*0.5);
    }
    return sqrt(err);
  }
  void print(ostream &out) const {  
    // write "x u_h(x) u(x)" at local points 0.25 and 0.75 of every cell,
    // with a blank line after each cell (presumably for gnuplot, since the
    // output files are named *.gnu); flush only once at the end
    for (int i=0;i<grid.size();i++) {
      out << grid.x(i,0.25) << " " << U(i,0.25) 
          << " " << dat.lsg(grid.x(i,0.25)) << "\n";
      out << grid.x(i,0.75) << " " << U(i,0.75)  
          << " " << dat.lsg(grid.x(i,0.75)) << "\n";
      out << "\n";
    }
    out << endl;
  }
  int solve(double tol) {
    const int ndofs = U.baseSize()*grid.size();
    int iter;
    if (alpha==-1) {
      // for alpha == -1 the matrix applied in mult() is symmetric, so the
      // symmetric solver cghs can be used
      iter=cghs(ndofs,*this,F.raw(),U.raw(),tol,false);
    }
    else {  // two possible non-sym. solvers
//      iter=bicgstab(ndofs,*this,F.raw(),U.raw(),tol,true);
      // first gmres argument is presumably the restart length; grid.size()/5
      // is 0 for fewer than 5 cells, which would leave GMRES no Krylov space
      int restart = grid.size()/5;
      if (restart < 1) restart = 1;
      iter=gmres(restart,ndofs,*this,F.raw(),U.raw(),tol,false);
    }
    return iter;
  }
  friend void mult<>(const DG<Data>&, const double*, double*);
};
// template <class Data>
// void mult(const DG<Data> &R, const double *v, double *w) {
//   // perform multiplication of SM A*v = w for w givven
//   // matrix for nodal basis with fixed size grid
//   double alpha=R.alpha;
//   double beta=R.beta;
//   for (int i=0;i<R.grid.size();i++) {
//     w[2*i]   = (3.-alpha+2.*beta)*v[2*i]  +(-3.+alpha)*v[2*i+1];
//     w[2*i+1] = (3.-alpha+2.*beta)*v[2*i+1]+(-3.+alpha)*v[2*i];
//     if (i>0) {
//       w[2*i]   += v[2*(i-1)]+(-1.+alpha-2.*beta)*v[2*(i-1)+1];
//       w[2*i+1] += (-alpha)*v[2*(i-1)+1];
//     }
//     if (i<R.grid.size()-1) {
//       w[2*i]   += (-alpha)*v[2*(i+1)];
//       w[2*i+1] += (-1.+alpha-2.*beta)*v[2*(i+1)]+v[2*(i+1)+1];
//     }
//     w[2*i]   /= 2.*R.grid.h(i);
//     w[2*i+1] /= 2.*R.grid.h(i);
//   }
// }
template <class Data>
void mult(const DG<Data> &R, const double *v, double *w) {
  // Matrix-free product w = A*v with the DG system matrix A for v given;
  // this is the operator the iterative solvers called from DG::solve use.
  // Unknowns are stored cell by cell: v[2*i+k] is local dof k of cell i in
  // the local basis of DiscLinFunc, phi_0 = 1 and phi_1 = l-0.5 (a
  // mean/slope basis, not a nodal one). Row pair (2*i, 2*i+1) couples cell i
  // with itself and its neighbours i-1 and i+1, so A is block tridiagonal
  // with 2x2 blocks. For alpha == -1 the coefficients below make A symmetric.
  double alpha=R.alpha;
  double beta=R.beta;
  double h = R.grid.h(0); // assuming uniform grid: every row is scaled by 1/h(0)
  
  for (int i=0;i<R.grid.size();i++) {
    // diagonal blocks (first and last cell carry extra boundary terms):
    // NOTE(review): with a single cell (R.grid.size()==1) only the i==0
    // branch is taken, so the right boundary terms are missing -- confirm
    // whether n == 1 is meant to be supported.
    if (i==0)
    {  
      w[2*i]   = (2.*beta)* v[2*i]    + 1./2.*                      v[2*i+1];
      w[2*i+1] = (-alpha/2.) * v[2*i] + (1.+3.*alpha + 2.*beta)/4.0*v[2*i+1];
    }
    else if (i==(R.grid.size()-1))
    {
      w[2*i]   = (2.*beta)* v[2*i]  - 1./2.*                     v[2*i+1];
      w[2*i+1] =  alpha/2.0 * v[2*i]+(1.+3.*alpha + 2.*beta)/4.0*v[2*i+1]; 
    }
    else 
    {
      // interior cell: the diagonal block is itself diagonal
      w[2*i]   = (2.*beta)*v[2*i];
      w[2*i+1] = (1.+alpha +beta)/2.0*v[2*i+1];      
    }

    // lower block-diagonal: coupling to the left neighbour i-1
    if (i>0) {
      w[2*i]   += -beta *           v[2*(i-1)]   + (1.-beta)/2.*v[2*(i-1)+1];
      w[2*i+1] +=  (beta+alpha)/2.* v[2*(i-1)]   + (beta+alpha-1.)/4.*v[2*(i-1)+1];
    }
    // upper block-diagonal: coupling to the right neighbour i+1
    if (i<R.grid.size()-1) {
      w[2*i]   += -beta   *        v[2*(i+1)] +(beta-1.)/2.*      v[2*(i+1)+1];
      w[2*i+1] += -(beta+alpha)/2.*v[2*(i+1)] +(beta-1.+alpha)/4.*v[2*(i+1)+1];
    }
    w[2*i]   /= h;
    w[2*i+1] /= h;
  }
}
/********************************************************
   CLASS for CG Method
********************************************************/
// Matrix-vector product needed by the iterative solver. It must be a friend
// of CG, which requires the forward declarations below; the function itself
// is defined after the class.
template <class Data> class CG;
template <class Data> void mult(const CG<Data>&, const double *, double *);
// discretization class
template <class Data>
class CG {
 protected:
  const Grid grid;
  const Data& dat;               // problem data (not owned, must outlive this object)
  double alpha,beta;             // stored only; not used by the CG method
  ContLinFunc F;                 // right hand side (load vector)
  ContLinFunc U;                 // solution, also the initial guess of the solver
 public:
  // constructor: pdat = problem data, pn = number of cells,
  // palpha/pbeta are accepted for symmetry with DG but unused
  CG(const Data& pdat,int pn,double palpha,double pbeta) :  
    grid(pn),
    dat(pdat),
    alpha(palpha), beta(pbeta), 
    F(grid),
    U(grid) {
    // Start from zero: F is accumulated with += below (neighbouring cells
    // share a node), and mult() never touches the boundary rows, so the
    // boundary values of U are presumably whatever the initial guess holds.
    // VectorT's constructor is not visible here, so do not rely on it.
    for (int i=0;i<grid.size();i++) {
      for (int k=0;k<F.baseSize();k++) {
        F.dof(i,k) = 0.;
        U.dof(i,k) = 0.;
      }
    }
    // assemble right hand side using 2 point Gauss quadrature on [0,1]
    const double l0 = (1.-sqrt(3.)/3.)*0.5; // gauss ... 
    const double l1 = (1.+sqrt(3.)/3.)*0.5; // ... points 
    const double weight = 0.5;
    for (int i=0;i<grid.size();i++) {
      for (int k=0;k<F.baseSize();k++) {
        F.dof(i,k) += F.phi(k,l0) * dat.f(grid.x(i,l0)) * grid.h(i)*weight;
        F.dof(i,k) += F.phi(k,l1) * dat.f(grid.x(i,l1)) * grid.h(i)*weight;
      }
    }
    // homogeneous Dirichlet data at x = -1 and x = 1
    F.dof(0,0) = 0.;
    F.dof(grid.size()-1,1) = 0.;
  }
  double error() const {  
    // L^2 error ||u_h - u|| using 2 point Gauss quadrature on each element
    const double l0 = (1.-sqrt(3.)/3.)*0.5;
    const double l1 = (1.+sqrt(3.)/3.)*0.5;
    double err = 0.;
    for (int i=0;i<grid.size();i++) {
      const double e0 = U(i,l0)-dat.lsg(grid.x(i,l0));
      const double e1 = U(i,l1)-dat.lsg(grid.x(i,l1));
      err += grid.h(i)*(e0*e0*0.5 + e1*e1*0.5);
    }
    return sqrt(err);
  }
  void print(ostream &out) const {  
    // write "x u_h(x) u(x)" at local points 0.25 and 0.75 of every cell,
    // with a blank line after each cell (presumably for gnuplot, since the
    // output files are named *.gnu); flush only once at the end
    for (int i=0;i<grid.size();i++) {
      out << grid.x(i,0.25) << " " << U(i,0.25) 
          << " " << dat.lsg(grid.x(i,0.25)) << "\n";
      out << grid.x(i,0.75) << " " << U(i,0.75)  
          << " " << dat.lsg(grid.x(i,0.75)) << "\n";
      out << "\n";
    }
    out << endl;
  }
  int solve(double tol) {
    // the nodal stiffness matrix applied in mult() is symmetric, so cghs is used
    // int iter=gmres(grid.size()/10,grid.size()+1,*this,F.raw(),U.raw(),tol,false);
    int iter=cghs(grid.size()+1,*this,F.raw(),U.raw(),tol,false);
    return iter;
  }
  friend void mult<>(const CG<Data>&, const double*, double*);
};
template <class Data>
void mult(const CG<Data> &R, const double *v, double *w) {
  // Matrix-free product w = A*v with the P1 (hat function) stiffness matrix,
  // used by the iterative solver in CG::solve. Nodes 0 and N = grid.size()
  // are the Dirichlet boundary: their rows are zero and their entries of v
  // are never read, which keeps the interior system symmetric.
  const int N = R.grid.size();
  w[0] = 0;
  w[N] = 0.;
  for (int node=1; node<N; ++node) {
    const double hRight = R.grid.h(node);     // cell to the right of the node
    const double hLeft  = R.grid.h(node-1);   // cell to the left of the node
    double row = v[node]*(hRight+hLeft)/(hRight*hLeft);
    if (node>1)   row -= v[node-1]/hLeft;
    if (node<N-1) row -= v[node+1]/hRight;
    w[node] = row;
  }
}
// now the call is very easy ...
int main(int argc, char ** argv, char ** envp) {
  // read command line parameters
  if (argc<5) {
    cout << "Wrong number of parameter:" << endl;
    cout << "Call: rwp N alpha beta problem" << endl;
    return 0;
  }
  int n=atoi(argv[1]);           
  double alpha=atof(argv[2]);    
  double beta=atof(argv[3]); 
  int prob = atoi(argv[4]);
  Data *dat = NULL;
  switch (prob) {
  case 1: dat = new DataSin(); break;
  case 2: dat = new DataTanh(); break;
  default: cout << "Wrong Problem" << endl; abort();
  }
  int iter;
  cout << dat->info() << endl; 
  DG<Data> dg(*dat,n,alpha,beta);    
  // solve 
  iter = dg.solve(1e-6);
  // print result
  cout << "DG required " << iter << " iterations" << flush;
  cout << " error: " << dg.error() << endl;
  ofstream outdg("rwp_dg.gnu");
  dg.print(outdg);
  CG<Data> cg(*dat,n,alpha,beta);    
  // solve 
  iter = cg.solve(1e-6);
  // print result
  cout << "CG required " << iter << " iterations" << flush;
  cout << " error: " << cg.error() << endl;
  ofstream outcg("rwp_cg.gnu");
  cg.print(outcg);
}
 

