#include "NagFitStrategy.h"
#include <TF1.h>
#include <TH1.h>
#include <TGraphErrors.h>

using namespace Sct;
namespace SctFitter{
   /// Construct a NAG-based fit strategy.
   /// @param opt option string handed on to the FitStrategy base class.
   /// The strategy's own name is set to "NagFitStrategy".
   NagFitStrategy::NagFitStrategy(string opt) throw()
	: FitStrategy(opt),
	  name("NagFitStrategy")
   {
   }
    
    /// Destructor. The body is empty: no explicit cleanup is performed here.
    NagFitStrategy::~NagFitStrategy() throw()
    {
    }

/// Documentation for the NAG C Library routine e04fcc (nag_opt_lsq_no_deriv):
/// http://wwwasd.web.cern.ch/wwwasd/lhc++/Nag_C/CL/doc/e/e04fcc.html

    /// Fit a histogram with the given function using the NAG least-squares routine.
    ///
    /// The bins to use are chosen by getNumberFitPoints(); their contents are
    /// copied into a Cache which is then handed to nagFit().
    /// @param hist histogram to fit (not modified by this function itself).
    /// @param fit  function to fit; nagFit() writes the results into it.
    /// @throw LogicError, MathsError as raised by the helpers called here.
    void NagFitStrategy::fitTH1(const TH1& hist, TF1& fit) const throw(LogicError, MathsError) {
	// getNumberFitPoints() takes a non-const TH1&, so constness has to be
	// cast away here.
	TH1& theHist = const_cast<TH1&> (hist);
	// use the range of fit to find the number of fit points
	// and the first active bin.
	int firstpoint=1;
	const int npoints = getNumberFitPoints(theHist, fit, firstpoint);

	// The cache lives in automatic storage so that it is destroyed (and its
	// arrays released) even if nagFit() exits by throwing.
	Cache cache(hist, npoints, firstpoint, fit);
	nagFit(&cache, fit, quiet());
    }


    /// Fit a TGraph with the given function using the NAG least-squares routine.
    ///
    /// The points to use are selected by getFitPoints(); the selected points are
    /// copied into a Cache (with unit errors) which is then handed to nagFit().
    /// @param graph graph to fit.
    /// @param fit   function to fit; nagFit() writes the results into it.
    /// @throw LogicError, MathsError as raised by the helpers called here.
    void NagFitStrategy::fitTGraph(const TGraph& graph, TF1& fit) const throw (LogicError, MathsError) {
	// getFitPoints() takes a non-const TGraph&, so constness is cast away.
	TGraph& theGraph = const_cast<TGraph&> (graph);
	// use the range of fit to find which points are active.
	vector<bool> active;
	getFitPoints(theGraph, fit, active);  // can use TGraph version.

	// The cache lives in automatic storage so that it is destroyed (and its
	// arrays released) even if nagFit() exits by throwing.
	Cache cache(graph, fit, active);
	nagFit(&cache, fit, quiet());
    }

    /// Fit a TGraphErrors with the given function using the NAG least-squares routine.
    ///
    /// The points to use are selected by getFitPoints(); the selected points and
    /// their y-errors are copied into a Cache which is then handed to nagFit().
    /// @param graph graph (with errors) to fit.
    /// @param fit   function to fit; nagFit() writes the results into it.
    /// @throw LogicError, MathsError as raised by the helpers called here
    ///        (the Cache constructor throws MathsError on a zero y-error).
    void NagFitStrategy::fitTGraphErrors(const TGraphErrors& graph, TF1& fit) const throw (LogicError, MathsError) {
	// getFitPoints() takes a non-const TGraph&, so constness is cast away.
	TGraphErrors& theGraph = const_cast<TGraphErrors&> (graph);
	// use the range of fit to find which points are active.
	vector<bool> active;
	getFitPoints(theGraph, fit, active);

	// The cache lives in automatic storage so that it is destroyed (and its
	// arrays released) even if nagFit() exits by throwing.
	Cache cache(graph, fit, active);
	nagFit(&cache, fit, quiet());
    }

    /// Run the NAG derivative-free least-squares minimiser on the cached data.
    ///
    /// @param cache data points, parameter arrays and the function to fit; it is
    ///              passed to the residual callback through Nag_Comm::p.
    /// @param fit   function that receives the fitted parameters, chi-squared
    ///              and number of degrees of freedom on success.
    /// @param quiet if true, NAG printing is switched off and a failed fit is
    ///              not reported.
    ///
    /// Returns without doing anything when the cache holds no points. If NAG
    /// reports an error other than NW_COND_MIN, `fit` is left untouched; unless
    /// `quiet` is set a MathsError is built and sent to MRS as a diagnostic
    /// (this function body does not itself throw it).
    void NagFitStrategy::nagFit(Cache* cache, TF1& fit, bool quiet) const throw(MathsError) {
	const unsigned npoints=cache->m_x.size();
	if (npoints==0) return;  // don't throw -- just return.
	
	/// make the Nag variables
	static NagError fail, fail2; 
	Nag_Comm comm;
	Nag_E04_Opt nagOptions;

	//initialize options structure.
	e04xxc(&nagOptions);
	
	if (quiet){
	    nagOptions.list=false;
	    fail.print = FALSE;
	    nagOptions.print_level=Nag_NoPrint;
	    //strcpy(nagOptions.outfile,"Nag.out");
	}
	double chiSq=0.;

	// Work arrays are heap-backed vectors rather than variable-length stack
	// arrays: VLAs are not standard C++ and their size here is unbounded.
	// fjac is a row-major npoints x nVarPars matrix; it is given at least
	// one element so that &fjac[0] is valid even with no free parameters
	// (NAG then rejects the call and the error branch below is taken).
	typedef std::vector<double>::size_type size_type;
	const size_type jacSize = static_cast<size_type>(npoints) * cache->nVarPars;
	std::vector<double> fjac(jacSize > 0 ? jacSize : 1);
	std::vector<double> fvec(npoints);

	comm.p = (Pointer) cache;

	// Call the optimisation routine:
	nag_opt_lsq_no_deriv(npoints, cache->nVarPars,  
			     &(this->chiSquaredFunctionForNag), 
			     cache->inPars, &chiSq, &fvec[0], 
			     &fjac[0], cache->nVarPars, &nagOptions, &comm, &fail);
	
	// release whatever the options structure holds.
	e04xzc(&nagOptions, "all", &fail2);

	// check for problems:
	if (fail.code != NE_NOERROR && fail.code != NW_COND_MIN){
	    if (!quiet){
		MathsError e("NagFitStrategy NAG ERROR", __FILE__, __LINE__);
		e.sendToMrs(MRS_DIAGNOSTIC);
	    }
	} else {
	    // Expand the fitted free parameters back into the full parameter list.
	    cache->convertPars(cache->inPars);
	    for (int i=0; i<fit.GetNpar(); ++i){
		fit.SetParameter(i,cache->pars[i]);
	    }
	    fit.SetChisquare(chiSq);
	    // Degrees of freedom = points used minus FREE parameters; fixed
	    // parameters do not consume a degree of freedom. Done in signed
	    // arithmetic to avoid unsigned wrap-around.
	    fit.SetNDF(static_cast<int>(npoints) - static_cast<int>(cache->nVarPars));
	}
    }

    /// @return a reference to this strategy's `name` member.
    const string& NagFitStrategy::getName() const throw()
    {
	return this->name;
    }
    
    /// Static registration: while this variable is initialised, a heap-allocated
    /// NagFitStrategy (empty options) is added to the FitStrategyFactory under the
    /// key "NagFitStrategy". The value returned by addToMap() is stored in inMap.
    /// The instance is never deleted in this file.
    bool NagFitStrategy::inMap =
	FitStrategyFactory::instance().addToMap("NagFitStrategy",
						*new NagFitStrategy(""));

    /// Work out which consecutive histogram bins take part in the fit.
    ///
    /// Without the 'R' option every bin 1..GetNbinsX() is used. With it, the
    /// range of `fit` selects the bins: the first bin is the one containing
    /// xmin and the bin containing xmax is treated as one-past-the-end.
    /// @param hist       histogram being fitted.
    /// @param fit        function whose range is used in ranged mode.
    /// @param firstpoint [out] index of the first bin to use (>= 1).
    /// @return number of bins to use starting at `firstpoint`; never negative.
    /// @throw InvalidArgumentError if the function range has min >= max.
    const int NagFitStrategy::getNumberFitPoints(TH1& hist, const TF1& fit, 
						 int& firstpoint) const throw (LogicError) {
	if (!ranged()){ // no 'R' so don't take range from TF1
	    firstpoint=1;
	    return hist.GetNbinsX();
	}

	/// silly Root hasn't got a const TF1::GetRange()
	TF1& theFit = const_cast<TF1&> (fit);
	double xmin, xmax;
	theFit.GetRange(xmin, xmax);
	if (xmin>=xmax) throw InvalidArgumentError("NagFitStrategy TF1 fit min>=max", __FILE__, __LINE__);
	TAxis* axis=hist.GetXaxis();

	firstpoint = axis->FindFixBin(xmin); 
	// Exclusive end of the bin range: the bin holding xmax is not fitted.
	int endBin = axis->FindFixBin(xmax); 
	if (firstpoint < 1) firstpoint = 1;
	// When xmax lies at or beyond the end of the axis, clamp the exclusive
	// end to one past the last bin so that the last bin is still included
	// (clamping to GetLast() itself would silently drop it).
	if (endBin > axis->GetLast()) endBin = axis->GetLast() + 1;

	// A range lying entirely outside the axis would give a negative count,
	// which callers convert to unsigned; report "no points" instead.
	const int npoints = endBin - firstpoint;
	return npoints > 0 ? npoints : 0;
    }

    /// Flag which points of a graph take part in the fit.
    ///
    /// @param graph  graph being fitted.
    /// @param fit    function whose range is used in ranged mode.
    /// @param active [out] resized to graph.GetN(); element i is true when
    ///               point i is to be used. Without the 'R' option all points
    ///               are used, otherwise those with xmin <= x <= xmax.
    /// @throw InvalidArgumentError if the function range has min >= max.
    void NagFitStrategy::getFitPoints(TGraph& graph, const TF1& fit, vector<bool>& active) const throw (LogicError) {
	if (!ranged()){ // no 'R' so don't take range from TF1
	    // every point is active.
	    active.assign(static_cast<vector<bool>::size_type>(graph.GetN()), true);
	    return;
	}

	active.resize(graph.GetN());

	/// silly Root hasn't got a const TF1::GetRange()
	TF1& mutableFit = const_cast<TF1&> (fit);
	double lower, upper;
	mutableFit.GetRange(lower, upper);
	if (lower>=upper) throw InvalidArgumentError("NagFitStrategy TF1 fit min>=max", __FILE__, __LINE__);

	double px, py;
	int idx = 0;
	while (idx < graph.GetN()){
	    graph.GetPoint(idx, px, py);
	    const bool inside = (px>=lower && px<=upper);
	    active[idx] = inside;
	    ++idx;
	}
    }

    /// Build a cache from a TGraph: stores the function, sets up the parameter
    /// arrays and copies the (x, y) of every active point. The graph carries no
    /// errors, so each point is given an error of 1.
    /// @param graph  source of the points.
    /// @param fit    function to be fitted; its address is kept in `function`.
    /// @param active per-point flag; only points flagged true are copied.
    NagFitStrategy::Cache::Cache(const TGraph& graph, TF1& fit, const vector<bool>& active) throw(LogicError, MathsError) {
	// GetPoint() is called through a non-const reference.
	TGraph& mutableGraph = const_cast<TGraph&> (graph);
	function=&fit;
	setupPars(fit);

	int idx = 0;
	while (idx < graph.GetN()){
	    if (active[idx]){
		double px, py;
		mutableGraph.GetPoint(idx, px, py);
		m_x.push_back(px);
		m_y.push_back(py);
		m_ey.push_back(1.); // set error to 1.
	    }
	    ++idx;
	}
    }

    /// Build a cache from a TGraphErrors: copies (x, y, y-error) of every active
    /// point, stores the function and sets up the parameter arrays.
    /// @param graph  source of the points and their y-errors.
    /// @param fit    function to be fitted; its address is kept in `function`.
    /// @param active per-point flag; only points flagged true are copied.
    /// @throw MathsError if an active point has a y-error of exactly zero.
    NagFitStrategy::Cache::Cache(const TGraphErrors& graph, TF1& fit, const vector<bool>& active) throw(LogicError, MathsError) {
	// GetPoint()/GetErrorY() are called through a non-const reference.
	TGraphErrors& theGraph=const_cast<TGraphErrors&> (graph);
	function=&fit;

	// Copy and validate the points BEFORE allocating the parameter arrays.
	// If the constructor throws, the destructor is not run, so anything
	// setupPars() had already allocated would be leaked.
	for ( int ipoint=0; ipoint<graph.GetN(); ++ipoint){
	    if (active[ipoint]){
		double x, y, err;
		theGraph.GetPoint(ipoint,x,y);
		err=theGraph.GetErrorY(ipoint);
		// a zero error would give a division by zero in the residuals.
		if (err==0.) throw MathsError("NagFitStrategy::Cache(graph) Error=0", __FILE__, __LINE__);
		m_x.push_back(x);
		m_y.push_back(y);
		m_ey.push_back(err);
	    }
	}

	setupPars(fit);
    }

    /// Build a cache from a histogram: copies centre, content and error of the
    /// selected bins, stores the function and sets up the parameter arrays.
    /// @param hist       source histogram.
    /// @param npoints    number of consecutive bins to examine.
    /// @param firstpoint index of the first bin to examine.
    /// @param fit        function to be fitted; its address is kept in `function`.
    /// @throw MathsError if an examined bin has an error of exactly zero.
    NagFitStrategy::Cache::Cache(const TH1& hist, const unsigned npoints, 
				 const unsigned firstpoint, TF1& fit) throw(LogicError, MathsError) {
	function=&fit;

	// Copy and validate the bins BEFORE allocating the parameter arrays.
	// If the constructor throws, the destructor is not run, so anything
	// setupPars() had already allocated would be leaked.
	for (unsigned i=0; i<npoints; ++i){
	    const unsigned bin = i + firstpoint;
	    const double err = hist.GetBinError(bin);
	    // a zero error would give a division by zero in the residuals.
	    if (err==0.) throw MathsError("NagFitStrategy::Cache(TH1) Error=0", __FILE__, __LINE__);
	    const double y=hist.GetBinContent(bin);
	    // bins whose content is smaller than twice their error are left out.
	    if (2*err>y) continue;
	    m_ey.push_back(err);
	    m_y.push_back(y);
	    m_x.push_back(hist.GetBinCenter(bin));
	}

	setupPars(fit);
    }
    
    /// Releases the three arrays allocated by setupPars().
    NagFitStrategy::Cache::~Cache() throw()
    {
	delete [] inPars;
	delete [] map;
	delete [] pars;
    }
    
    /// Scatter the free-parameter values into the full parameter array.
    /// @param inPars values of the nVarPars free parameters; entry i is written
    ///               to pars[map[i]]. Other entries of `pars` are left unchanged.
    void NagFitStrategy::Cache::convertPars(double inPars[]) {
	unsigned int iVar = 0;
	while (iVar < nVarPars) {
	    const int iFull = map[iVar];
	    pars[iFull] = inPars[iVar];
	    ++iVar;
	}
    }
	    	    
    /// Allocate the parameter arrays and split the function's parameters into
    /// fixed and free ones.
    ///
    /// A parameter is treated as fixed when its lower and upper limits are equal
    /// and non-zero. Fixed parameters are written straight into `pars`; for each
    /// free parameter its index is recorded in `map` and its current value in
    /// `inPars`. `nVarPars` ends up as the number of free parameters.
    /// @param fit function whose parameters and limits are read.
    void NagFitStrategy::Cache::setupPars(TF1& fit) {
	const int npar = fit.GetNpar();
	map = new int[npar];
	pars = new double[npar];
	inPars = new double[npar];
	nVarPars = 0;
	double lowLim, highLim;
	
	for (int i=0; i<npar; ++i) {
	    fit.GetParLimits(i, lowLim, highLim);
	    if (lowLim == highLim && lowLim != 0) {
		// Take the fixed value from the parameter itself, not from the
		// limit: TF1::FixParameter(i, 0) stores the limits as (1, 1),
		// so using the limit would turn a parameter fixed at 0 into 1.
		pars[i] = fit.GetParameter(i);
	    } else {
		map[nVarPars] = i;
		inPars[nVarPars] = fit.GetParameter(i);
		++nVarPars;
	    }
	}
    }


    /// Residual callback handed to the NAG least-squares routine.
    /// NB set up cache first! The Cache is reached through comm->p.
    /// @param nresid  number of residuals to fill.
    /// @param nvar    number of free parameters (not used in the body).
    /// @param x_pars  current values of the free parameters.
    /// @param fvec    [out] fvec[i] = (f(x_i) - y_i) / ey_i for each cached point.
    /// @param comm    NAG communication structure carrying the Cache pointer.
    void NagFitStrategy::chiSquaredFunctionForNag(Integer nresid, Integer nvar, double x_pars[], 
					    double fvec[], Nag_Comm *comm) throw() {	
	Cache* const fitCache = (Cache*) comm->p;
	// expand the free parameters into the full parameter array first.
	fitCache->convertPars(x_pars);
	for (int ipt=0; ipt < nresid; ++ipt){
	    const double model = fitCache->function->EvalPar( &(fitCache->m_x[ipt]) , fitCache->pars);
	    const double delta = model - fitCache->m_y[ipt];
	    fvec[ipt] = delta / (fitCache->m_ey[ipt]);
	}
    }

    bool NagFitStrategy::quiet() const throw() {
	return ( getOptions().find('Q')!=string::npos || 
		 getOptions().find('q')!=string::npos );
    }
    
    bool NagFitStrategy::ranged() const throw() {
	return ( getOptions().find('R')!=string::npos || 
		 getOptions().find('r')!=string::npos );
    }

}  // end of namespace SctFitter;

