/***************************************************************************
 *                                                                         *
 *                  (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include "testmethod.h"

#include <cstring>
#include <ctime>
#include <fstream>
#include <math.h>
#include <time.h>
#include <ctime>
#include <stdio.h>

#include "testmethod.h"
#include "constant.h"
#include "rate.h"
#include "model.h"
#include "ali.h"
#include "utl.h"
#include "outstream.h"
#include "usertree.h"
#include "ptnls.h"
#include "treels.h"
#include "interface.h"
#include "iqp.h"

#ifdef PARALLEL
//#include <mpi.h>
#endif

using namespace std;

//--------------------------------------------------------------------
/**
	Initialisation hook, meant to run before this method is tested with data.
	It currently performs no work.
*/
void TestMethod::init () {
}

//--------------------------------------------------------------------
/**
	Transfer the settings held in params to the model object (mymodel) and
	the rate object (myrate), then initialise / create both.

	The statements are order sensitive (setters first, then init()/create()),
	so the sequence must be kept as it is.

	@param params  source of every model and rate setting copied here
*/
void TestMethod::setPam (CommonParameters &params) {

	// ---- nsSy_* settings of the rate object ----
	myrate.nsSy_classes = params.nsSy_classes;
	myrate.nsSy_ratio_type = params.nsSy_ratio_type;

	// only the first nsSy_classes doubles of each array are copied
	// NOTE(review): the capacity of the destination arrays is not visible here
	const std::size_t nsSyBytes_ = params.nsSy_classes * sizeof(double);
	memmove(myrate.nsSy_ratio_val, params.nsSy_ratio_val, nsSyBytes_);
	memmove(myrate.nsSy_prob_val, params.nsSy_prob_val, nsSyBytes_);

	// ---- codon related settings, stored on the model for later use ----
	mymodel.codon_model = params.codon_model;
	mymodel.codonFrqType = params.codonFrqType;
	mymodel.codon_tsTvRatio = params.codon_tsTvRatio;
	mymodel.codon_tsTvRatioType = params.codon_tsTvRatioType;

	// ---- data type and substitution model ----
	mymodel.setDataType (params.dataType);
	mymodel.setModelType (params.model);

	// the model receives twice the value stored in params
	mymodel.setTsTvRatio (params.tsTvRatio * 2.0);
	mymodel.setTsTvRatioType (params.tsTvRatioType);

	mymodel.setPyPuRatio (params.pyPuRatio);
	mymodel.setPyPuRatioType (params.pyPuRatioType);

	mymodel.setGenPam (params.tsAG, params.tsCT, params.tvAC, params.tvAT, params.tvCG, params.tvGT);
	mymodel.setGenPamType (params.genPamType);

	mymodel.setStateFrq (params.baseA, params.baseC, params.baseG, params.baseT);
	mymodel.setBaseFrqType (params.baseFrqType);

	mymodel.pam_brent = params.pam_brent;
	mymodel.rate_file = params.rate_file;
	mymodel.ap_sitespec = params.ap_sitespec;

	mymodel.init ();

	// ---- rate object ----
	myrate.init ();
	myrate.setType (params.rateType);
	// the number of rates is only forwarded for the GAMMA rate type
	if (params.rateType == GAMMA) {
		myrate.setNRate (params.nRate);
	}
	myrate.setGammaShape (params.gammaShape);
	myrate.setGammaShapeType (params.gammaShapeType);

	myrate.prob_invar_site = params.prob_invar;
	myrate.invar_site_type = params.prob_invar_type;

	myrate.create ();
}


/** Lookup table filled by create_nto_char(); indexed by the values stored in nto. */
char nto_char[NUM_CHAR];

/**
	Build nto_char as the reverse of the nto table: for every index c below
	NUM_CHAR, the entry at position nto[c] is set to c. Indices are visited in
	ascending order, so when several indices share one nto value the largest
	index wins. Finally the BS_UNKNOWN entry is forced to '-'.

	NOTE(review): presumably nto maps a character code to its internal state
	code, which would make nto_char the state-to-character table; confirm
	against the definition of nto.
*/
void create_nto_char() {
	int charCode_ = 0;
	while (charCode_ < NUM_CHAR) {
		const unsigned char slot_ = (unsigned char)nto[charCode_];
		nto_char[slot_] = charCode_;
		++charCode_;
	}
	nto_char[BS_UNKNOWN] = '-';
}

/**
	Write the pattern frequencies and pattern log-likelihoods to the file
	named by getOutFileName(SUF_PATLH, ...).

	One line is written per pattern: the pattern itself (one character per
	sequence, translated through nto_char), two blanks, the pattern weight,
	two blanks and the pattern log-likelihood.

	If the file cannot be opened the error is reported through
	Utl::announceError() and nothing else is done.
*/
void writePatternLogl() {

	create_nto_char();

	std::string patlh_name;
	getOutFileName(SUF_PATLH, patlh_name);

	std::ofstream out(patlh_name.c_str());
	if (!out.is_open()) {
		// build the message in its own string so that the file name is not clobbered
		const std::string errMsg_ = "Cannot write patterns-loglikelihood to file " + patlh_name;
		Utl::announceError(errMsg_.c_str());
		// in case announceError() returns: do not compute anything for, or
		// claim success on, a file that could not be opened
		return;
	}

	const int nPtn_ = ptnlist.getNPtn ();

	// compute the pattern log-likelihoods
	// NOTE(review): ptn_logl is presumably filled by these calls -- confirm
	createPtnLiInfo();
	opt_urtree.cmpLiNd();
	opt_urtree.cmpLogLi();

	// the number of sequences does not change while writing, so query it once
	const int nseqs = alignment.getNSeq();

	// writing output file
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		for (int i = 0; i < nseqs; i++)
			out << nto_char[(int)ptnlist.getBase(ptnNo_, i)];
		// '\n' instead of endl: no flush per pattern, close() flushes once at the end
		out << "  " << ptnlist.weightArr_[ptnNo_] << "  " << ptn_logl[ptnNo_] << '\n';
	}

	out.close();
	std::cout << "Pattern loglikelihood is printed to " << patlh_name << std::endl;

	deletePtnLiInfo();
}


/**
	Write the per-site log-likelihoods to the file named by
	getOutFileName(SUF_SITELH, ...).

	Output layout: a header line "1  <number of sites>", then the label
	"iqp_tree" followed by one value per site (10 significant digits),
	each value followed by two blanks. The value of a site is the
	log-likelihood of the pattern the site belongs to.

	If the file cannot be opened the error is reported through
	Utl::announceError() and nothing else is done.
*/
void writeSiteLogl() {

	std::string sitelh_name;
	getOutFileName(SUF_SITELH, sitelh_name);

	std::ofstream out(sitelh_name.c_str());
	if (!out.is_open()) {
		// build the message in its own string so that the file name is not clobbered
		const std::string errMsg_ = "Cannot write site-loglikelihood to file " + sitelh_name;
		Utl::announceError(errMsg_.c_str());
		// in case announceError() returns: do not compute anything for, or
		// claim success on, a file that could not be opened
		return;
	}

	const int nSite_ = alignment.getNSite();

	// compute the pattern log-likelihoods
	// NOTE(review): ptn_logl is presumably filled by these calls -- confirm
	createPtnLiInfo();
	opt_urtree.cmpLiNd();
	opt_urtree.cmpLogLi();

	// nto_char is not read in this function; the call is kept because it
	// refreshes a global table that other code may rely on
	create_nto_char();

	out << "1  " << nSite_ << std::endl << "iqp_tree  ";
	out.precision(10);

	// writing output file: every site reports the value of its pattern
	for (int siteNo_ = 0; siteNo_ < nSite_; siteNo_ ++) {
		const int ptnNo_ = alignment.getPtn(siteNo_);
		out << ptn_logl[ptnNo_] << "  ";
	}

	out.close();
	std::cout << "Site-loglikelihood is printed to " << sitelh_name << std::endl;
	deletePtnLiInfo();
}


//--------------------------------------------------------------------
/**
	The main method of the IQPNNI algorithm.

	Steps, in order:
	  1. rebuild the pattern list and configure model/rate via setPam();
	  2. run the tree search (IQP::escapeLocalOpt);
	  3. if mymodel.isCodonAnalysisNext() holds, switch to codon data and
	     alternate parameter optimisation with tree optimisation until the
	     topology no longer changes;
	  4. root the tree and, on the master process only, write the report and
	     the optional bootstrap / pattern-logl / site-logl files.

	@param params      run parameters; nRep is modified, and progTime is set
	                   after a codon analysis
	@param in_pam      input parameters, forwarded to the search and the report
	@param beginTime_  start time of the run; used for the report date and for
	                   the elapsed time
	@return always 1
*/
int TestMethod::create (CommonParameters &params, InputParameters &in_pam, time_t &beginTime_) {

	// ctime() returns a pointer to static storage which any later
	// ctime()/asctime() call may overwrite. The text is only used at the very
	// end of this function, after the whole search, so keep a private copy.
	char startedDateBuf_[64] = "";
	const char *ctimeText_ = ctime(&beginTime_);
	if (ctimeText_)
		strncpy(startedDateBuf_, ctimeText_, sizeof(startedDateBuf_) - 1);
	char *startedDate_ = startedDateBuf_;

	Vec<char> conTree_;

	IQP iqp_;

	//create the pattern list
	ptnlist.release();
	ptnlist.init ();
	ptnlist.create ();
	if (isMasterProc())
		std::cout << "Number patterns = " << ptnlist.getNPtn () << std::endl;

	setPam (params);

	if (mymodel.isCodonAnalysisNext()) {
		if (isMasterProc())
			std::cout << "===> First use nucleotide substitution model to search for tree topology..." << std::endl;
	}
	alignment.init ();

	// one more than requested, capped at MAX_NUM_REP - 1
	params.nRep = Utl::getMin (params.nRep + 1, MAX_NUM_REP - 1);

	// main optimization step: remove nodes, reinsert by IQP, branch swapping by NNI
	UserTree tree_;
	iqp_.escapeLocalOpt (params, in_pam, tree_, conTree_, beginTime_);

	if (mymodel.isCodonAnalysisNext()) {
		// free memory
		opt_urtree.liBrArr_.release();
		mymodel.initialized_tsTv = true;
		startCodonAnalysis();
		tree_.convertBrLenToCodon();

		double userLogLi_ ;
		// set by every doConOpt() call below
		bool topo_changed;

		alignment.isGenDisMatCmped_ = 1;
		userLogLi_ =  tree_.doConOpt (0, topo_changed, 1);
		if (isMasterProc())
			std::cout << "Initial LogL = " << userLogLi_ << std::endl;

		// alternate parameter and tree optimisation until the topology is stable
		int count = 0;
		do {
			count++;
			iqp_.optPam();
			if (isMasterProc())
				std::cout <<"Optimizing the tree topology as well as branch lengths..." << std::endl;

			// write the intermediate tree to file; snprintf cannot run past the
			// end of treeout even when the output prefix is very long
			char treeout[FILENAME_LEN];
			snprintf(treeout, sizeof(treeout), "%s.iqpnni.tree%d", alignment.out_prefix, count);
			tree_.writeTop(treeout, 0);

			userLogLi_ = tree_.doConOpt(1, topo_changed);

			if (topo_changed) {
				if (isMasterProc())
					std::cout <<"Tree topology was changed, reoptimize paramater(s)..." << std::endl;
				tree_.createOptUrTree();
			}
		} while (topo_changed);

		std::cout.precision(10);
		if (isMasterProc())
			std::cout <<"Log Likelihood = " << userLogLi_ << std::endl;

		// elapsed time since the start of the run
		time_t currentTime_;
		time(&currentTime_);

		params.progTime = difftime (currentTime_, beginTime_);
	}

	Interface interface_;
	tree_.createRootedTree(params.outGrpSeqNo);

	// all output files are produced by the master process only
	if (isMasterProc()) {
		interface_.writeOut (params, in_pam, tree_, conTree_, startedDate_);

		if (params.gen_bootstrap) {
			string boottreeFileName_;
			getOutFileName(SUF_BOOTTREE, boottreeFileName_);
			// second argument is 0 for the first bootstrap sample, 1 afterwards
			// NOTE(review): presumably 0 = create the file, 1 = append -- confirm in UserTree::writeTop
			if (params.cur_bootstrap == 0)
				tree_.writeTop(boottreeFileName_.c_str(), 0);
			else
				tree_.writeTop(boottreeFileName_.c_str(), 1);
		}
		if (params.print_pat_logl) {
			writePatternLogl();
		}
		if (params.print_site_logl) {
			writeSiteLogl();
		}
	}

	return 1;
}

/**
	Prepare for the next bootstrap sample by releasing the current pattern list.
*/
void TestMethod::startNextBootstrap()
{
	ptnlist.release();
}


/**
	Switch the alignment, the pattern list and the model over to codon data.
	Returns immediately unless mymodel.isCodonAnalysisNext() holds.
*/
void TestMethod::startCodonAnalysis() {
	const bool codonPending_ = mymodel.isCodonAnalysisNext();
	if (!codonPending_) {
		return;
	}

	if (isMasterProc()) {
		std::cout << "===> Now run codon analysis" << std::endl;
	}

	// convert the alignment to codons and rebuild the pattern list from it
	alignment.convertDataType(CODON);
	ptnlist.release();
	ptnlist.init ();
	ptnlist.create ();
	if (isMasterProc()) {
		std::cout << "Number patterns = " << ptnlist.getNPtn () << std::endl;
	}

	// read the ts/tv ratio from the current model_ and store it on mymodel
	// before mymodel is released
	mymodel.setTsTvRatio(mymodel.model_->getTsTvRatio());

	// reconfigure the model for codon data with the stored codon settings
	mymodel.release();
	mymodel.setDataType (CODON);
	mymodel.setModelType (mymodel.codon_model);

	mymodel.setTsTvRatioType (mymodel.codon_tsTvRatioType);
	// a user defined codon ratio replaces the value carried over above
	const bool userDefinedRatio_ = (mymodel.getTsTvRatioType() == USER_DEFINED);
	if (userDefinedRatio_) {
		mymodel.setTsTvRatio (mymodel.codon_tsTvRatio * 2.0);
	}

	mymodel.setBaseFrqType (mymodel.codonFrqType);

	mymodel.init ();
}

