#include "tmvaglob.C"

// This macro draws the layout of a neural network trained with TMVA's
// MethodMLP: its neurons, activation functions and weighted synapses.
// @author: Matt Jachowski, jachowski@stanford.edu

// File opened by network(); get_var_names() reads the input-variable
// directories from it. Stays 0 until network() has opened a file.
TFile* Network_GFile = 0;

// Canvas background colour used while the network diagram is drawn
// (draw_network() temporarily installs it in the "TMVA" style).
static Int_t c_DarkBackground = TColor::GetColor( "#6e7a85" );


// Draw the layout of every MLP neural network stored in a TMVA output file.
//
// input: - fin:          input file (result from TMVA)
//        - useTMVAStyle: use of TMVA plotting TStyle
//
// Every TDirectory title found inside the MLP method directory is handed to
// draw_network(), which renders that network on its own canvas.
void network( TString fin = "TMVA.root", Bool_t useTMVAStyle = kTRUE )
{
   // set style and remove existing canvas'
   TMVAGlob::Initialize( useTMVAStyle );

   // checks if file with name "fin" is already open, and if not opens one
   TFile* file = TMVAGlob::OpenFile( fin );
   if (file == 0) {
      cout << "Could not open file " << fin << endl;
      return;
   }
   // remembered globally: get_var_names() looks up the variable names in it
   Network_GFile = file;

   TKey* mkey = TMVAGlob::FindMethod("MLP");
   if (mkey==0) {
      cout << "Could not locate directory MLP in file " << fin << endl;
      return;
   }
   TDirectory *dir = (TDirectory *)mkey->ReadObj();
   if (dir==0) {
      cout << "Could not read directory MLP from file " << fin << endl;
      return;
   }
   dir->cd();

   TList titles;
   UInt_t ni = TMVAGlob::GetListOfTitles( dir, titles );
   if (ni==0) {
      cout << "No titles found for Method_MLP" << endl;
      return;
   }

   // one network diagram per title sub-directory
   TIter nextTitle(&titles);
   TKey *titkey;
   while ((titkey = TMVAGlob::NextKey(nextTitle,"TDirectory"))) {
      TDirectory *titDir = (TDirectory *)titkey->ReadObj();
      if (titDir == 0) continue; // unreadable entry: skip, keep drawing the others
      cout << "Drawing title: " << titDir->GetName() << endl;
      draw_network(titDir);
   }
}

// Draw the full layout of one MLP network stored in directory `d`.
//
// The network is described by the TH2F histograms in `d` whose names contain
// "weights_hist". Histogram k holds the weights from layer k (x bins) to
// layer k+1 (y bins), so N such histograms describe N+1 layers. They are
// used in key order, which is assumed to be layer order (TODO confirm).
//
// The canvas is passed to TMVAGlob::imgconv() with the name "plots/network".
// NOTE(review): if imgconv saves to that name, every title overwrites the
// previous image -- confirm whether a per-title file name is wanted.
void draw_network(TDirectory* d)
{
   Bool_t printLogo = kTRUE;

   // darken the canvas background of the TMVA style while drawing; the
   // original colour is restored at the end. The style may be absent.
   TStyle* TMVAStyle = gROOT->GetStyle("TMVA"); // the TMVA style
   Int_t canvasColor = 0;
   if (TMVAStyle != 0) {
      canvasColor = TMVAStyle->GetCanvasColor(); // backup
      TMVAStyle->SetCanvasColor( c_DarkBackground );
   }

   // create canvas (after the style change so it picks up the dark background)
   TCanvas* c = new TCanvas( "c", "Neural Network Layout", 100, 0, 1000, 650 );

   // collect all weight histograms in a single pass over the keys, so each
   // object is read from the file only once; the array does not own them
   const TString hName = "weights_hist";
   TObjArray weightHists;
   TIter next(d->GetListOfKeys());
   TKey *key;
   while ((key = (TKey*)next())) {
      TClass *cl = gROOT->GetClass(key->GetClassName());
      if (cl == 0 || !cl->InheritsFrom("TH2F")) continue;
      TH2F *h = (TH2F*)key->ReadObj();
      if (h != 0 && TString(h->GetName()).Contains( hName )) weightHists.Add(h);
   }
   const Int_t numHists = weightHists.GetEntriesFast();

   // largest |weight| over all layers, used to normalise the synapses
   Double_t maxWeight = 0;
   for (Int_t ih = 0; ih < numHists; ih++) {
      TH2F *h = (TH2F*)weightHists.At(ih);
      for (Int_t i = 1; i <= h->GetNbinsX(); i++) {
         for (Int_t j = 1; j <= h->GetNbinsY(); j++) {
            Double_t weight = TMath::Abs(h->GetBinContent(i, j));
            if (maxWeight < weight) maxWeight = weight;
         }
      }
   }
   // all weights zero: avoid a 0/0 normalisation (every synapse stays hidden)
   if (maxWeight <= 0) maxWeight = 1;

   // draw network: histogram ih connects layer ih to layer ih+1
   for (Int_t ih = 0; ih < numHists; ih++) {
      draw_layer(c, (TH2F*)weightHists.At(ih), ih, numHists+1, maxWeight);
   }

   draw_layer_labels(numHists+1);

   // ============================================================
   if (printLogo) TMVAGlob::plot_logo();
   // ============================================================

   c->Update();

   TString fname = "plots/network";
   TMVAGlob::imgconv( c, fname );

   if (TMVAStyle != 0) TMVAStyle->SetCanvasColor( canvasColor );
}

// Draw a "Layer <n>" label box along the bottom edge of the canvas, centred
// under each of the nLayers layer columns.
//
// The horizontal layout matches draw_layer(): the leftmost LABEL_WIDTH of the
// pad is reserved for the input-variable labels, and the rest is split evenly
// between the layers.
void draw_layer_labels(Int_t nLayers)
{
   const Double_t LABEL_HEIGHT = 0.03;
   const Double_t LABEL_WIDTH  = 0.20;

   const Double_t columnWidth = 0.8*(1.0-LABEL_WIDTH)/nLayers;
   const Double_t halfBox     = 0.8*columnWidth/2.0;
   const Double_t boxHeight   = 0.8*LABEL_HEIGHT;
   const Double_t bottom      = LABEL_HEIGHT - boxHeight;
   const Double_t top         = bottom + boxHeight;

   Int_t layer = 0;
   while (layer < nLayers) {
      const Double_t centre = layer*(1.0-LABEL_WIDTH)/nLayers
                            + 1.0/(2.0*nLayers) + LABEL_WIDTH;

      TString text = Form("Layer %i", layer);
      TPaveLabel *box = new TPaveLabel(centre-halfBox, bottom,
                                       centre+halfBox, top, text.Data(), "br");
      box->SetFillColor(gStyle->GetTitleFillColor());
      box->SetFillStyle(1001);
      box->Draw();

      ++layer;
   }
}

// Draw one label box per input-layer neuron, to the left of the network.
//
//   nInputs    - number of input-layer neurons (input variables + bias node)
//   cy         - vertical centre (NDC) of each input neuron, nInputs entries
//   rad        - neuron radius (NDC), used to size the label width
//   layerWidth - horizontal width (NDC) of one layer column
//
// Labels are the variable names returned by get_var_names(). The last neuron
// is labelled "Bias node" and drawn with text colour 9. If the names cannot
// be determined, no labels are drawn and the rest of the plot is kept.
void draw_input_labels(Int_t nInputs, Double_t* cy,
                       Double_t rad, Double_t layerWidth)
{
   const Double_t LABEL_HEIGHT = 0.03;
   const Double_t LABEL_WIDTH  = 0.20;
   Double_t width = LABEL_WIDTH + (layerWidth-4*rad);
   Double_t margX = 0.01;
   Double_t effHeight = 0.8*LABEL_HEIGHT;

   TString *varNames = get_var_names(nInputs);
   // get_var_names() has already reported the problem; skip the labels
   // instead of terminating the whole ROOT session with exit()
   if (varNames == 0) return;

   TString input;

   for (Int_t i = 0; i < nInputs; i++) {
      if (i != nInputs-1) input = varNames[i];
      else                input = "Bias node";
      Double_t x1 = margX;
      Double_t x2 = margX + width;
      Double_t y1 = cy[i] - effHeight;
      Double_t y2 = cy[i] + effHeight;

      TPaveLabel *p = new TPaveLabel(x1, y1, x2, y2, input.Data(), "br");
      p->SetFillColor(gStyle->GetTitleFillColor());
      p->SetFillStyle(1001);
      p->Draw();
      if (i == nInputs-1) p->SetTextColor(9);
   }

   delete[] varNames;
}

// Determine the input-variable names from the histograms stored in
// Network_GFile.
//
// The first existing directory among "InputVariables_NoTransform",
// "InputVariables_DecorrTransform" and "InputVariables_PCATransform" is made
// the current directory (side effect: dir->cd()). It is then scanned, in key
// order, for cycle-1 TH1 objects whose key name contains "__S" (presumably
// the signal distributions -- confirm). Their titles are used as the names.
//
// nVars is the number of input-layer neurons, including the bias node, so
// nVars-1 names are expected; a different count is reported on cout.
//
// Returns a new[]-allocated array of nVars TStrings, with unfilled entries
// left empty; the caller must delete[] it. Returns 0 if no file is open or no
// input-variable directory is found.
TString* get_var_names( Int_t nVars )
{
   const TString directories[3] = { "InputVariables_NoTransform",
                                    "InputVariables_DecorrTransform",
                                    "InputVariables_PCATransform" };

   if (Network_GFile == 0) {
      cout << "*** Big troubles in macro \"network.C\": no input file is open, "
           << "and hence could not determine variable names --> abort" << endl;
      return 0;
   }

   TDirectory* dir = 0;
   for (Int_t i=0; i<3; i++) {
      dir = (TDirectory*)Network_GFile->Get( directories[i] );
      if (dir != 0) break;
   }
   if (dir==0) {
      cout << "*** Big troubles in macro \"network.C\": could not find directory for input variables, "
           << "and hence could not determine variable names --> abort" << endl;
      return 0;
   }
   cout << "--> go into directory: " << dir->GetName() << endl;
   dir->cd();

   TString* vars = new TString[nVars];
   Int_t ivar = 0;

   // loop over all objects in directory
   TIter next(dir->GetListOfKeys());
   TKey* key = 0;
   while ((key = (TKey*)next())) {
      if (ivar >= nVars) break; // never write past the end of vars

      if (key->GetCycle() != 1) continue;
      if (!TString(key->GetName()).Contains("__S")) continue;

      // make sure, that we only look at histograms
      TClass *cl = gROOT->GetClass(key->GetClassName());
      if (cl == 0 || !cl->InheritsFrom("TH1")) continue;
      TH1 *sig = (TH1*)key->ReadObj();
      if (sig == 0) continue;

      TString hname = sig->GetTitle();
      vars[ivar] = hname;
      ivar++;
   }

   if (ivar != nVars-1) { // bias layer is also in nVars counts
      cout << "*** Troubles in \"network.C\": did not reproduce correct number of "
           << "input variables: " << ivar << " != " << nVars-1 << endl;
   }

   return vars;
}

// Draw the picture of a neuron's activation function in a small pad centred
// on (cx, cy). The pad's half-sizes are 0.7*radx and 0.7*rady (NDC of c).
//
//   whichActivation - 0: sigmoid image, 1: linear ("line") image
//
// The PNG files are opened relative to the working directory
// ("../macros/..."). When no usable image is available, a message is
// printed and nothing is drawn. Canvas c is made the current pad again
// afterwards.
void draw_activation(TCanvas* c, Double_t cx, Double_t cy,
                     Double_t radx, Double_t rady, Int_t whichActivation)
{
   TImage *activation = NULL;

   switch (whichActivation) {
   case 0:
      activation = TImage::Open("../macros/sigmoid-small.png");
      break;
   case 1:
      activation = TImage::Open("../macros/line-small.png");
      break;
   default:
      cout << "Activation index " << whichActivation << " is not known." << endl;
      cout << "You messed up or you need to modify network.C to introduce a new "
           << "activation function (and image) corresponding to this index" << endl;
      break;
   }

   // an image object that failed to load (e.g. the PNG is not found relative
   // to the working directory) would only produce paint errors: discard it
   if (activation != NULL && !activation->IsValid()) {
      delete activation;
      activation = NULL;
   }

   if (activation == NULL) {
      cout << "Could not create an image... exit" << endl;
      return;
   }

   activation->SetConstRatio(kFALSE);

   radx *= 0.7;
   rady *= 0.7;
   TString name = Form("activation%f%f", cx, cy);
   TPad* p = new TPad(name+"", name+"", cx-radx, cy-rady, cx+radx, cy+rady);

   p->Draw();
   p->cd();

   activation->Draw();
   c->cd();
}

// Draw the neurons of the two layers connected by weight histogram h and
// all synapses between them.
//
//   c         - canvas to draw on; its window aspect ratio keeps the neurons
//               round, and it is the parent of the activation-image pads
//   h         - weights: x bins = neurons of layer iHist, y bins = neurons of
//               layer iHist+1, bin content = connection weight
//   iHist     - index of the source layer
//   nLayers   - total number of layers (number of weight histograms + 1)
//   maxWeight - largest |weight| of the network; weights are divided by it
//               before being passed to draw_synapse()
//
// The source layer's neurons (and the input-variable labels) are drawn only
// for iHist == 0; later layers were already drawn as the target layer of the
// previous call. The last bin of a layer is drawn lowest and outlined in
// colour 9 (the bias node; not marked for a single-neuron target layer).
// Layers with more than MAX_NEURONS_NICE neurons get smaller neurons and no
// activation images.
void draw_layer(TCanvas* c, TH2F* h, Int_t iHist,
                Int_t nLayers, Double_t maxWeight)
{
   const Double_t MAX_NEURONS_NICE = 12;
   const Double_t LABEL_HEIGHT = 0.03;
   const Double_t LABEL_WIDTH  = 0.20;
   Double_t ratio = ((Double_t)(c->GetWindowHeight())) / c->GetWindowWidth();
   Double_t rad, cx1, *cy1, cx2, *cy2;

   // this is the smallest radius that will still display the activation images
   rad = 0.04*650/c->GetWindowHeight();

   // source layer (iHist): one neuron per x bin
   Int_t nNeurons1 = h->GetNbinsX();
   cx1 = iHist*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
   cy1 = new Double_t[nNeurons1];

   // target layer (iHist+1): one neuron per y bin
   Int_t nNeurons2 = h->GetNbinsY();
   cx2 = (iHist+1)*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
   cy2 = new Double_t[nNeurons2];

   // shrink the neurons of large layers so that they still fit vertically
   Double_t effRad1 = rad;
   if (nNeurons1 > MAX_NEURONS_NICE)
      effRad1 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons1);

   // cy1[k] is the centre of the neuron for x bin k+1 (last bin lowest)
   for (Int_t i = 0; i < nNeurons1; i++) {
      cy1[nNeurons1-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons1 +
         1.0/(2.0*nNeurons1) + LABEL_HEIGHT;

      if (iHist == 0) {

         TEllipse *ellipse
            = new TEllipse(cx1, cy1[nNeurons1-i-1],
                           effRad1*ratio, effRad1, 0, 360, 0);
         ellipse->SetFillColor(TColor::GetColor( "#fffffd" ));
         ellipse->SetFillStyle(1001);
         ellipse->Draw();

         if (i == 0) ellipse->SetLineColor(9); // bias node (last bin)

         if (nNeurons1 > MAX_NEURONS_NICE) continue;

         // input-layer neurons always get the linear activation image
         draw_activation(c, cx1, cy1[nNeurons1-i-1],
                         rad*ratio, rad, 1);
      }
   }

   if (iHist == 0) draw_input_labels(nNeurons1, cy1, rad, (1.0-LABEL_WIDTH)/nLayers);

   Double_t effRad2 = rad;
   if (nNeurons2 > MAX_NEURONS_NICE)
      effRad2 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons2);

   // cy2[k] is the centre of the neuron for y bin k+1 (last bin lowest)
   for (Int_t i = 0; i < nNeurons2; i++) {
      cy2[nNeurons2-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons2 + 1.0/(2.0*nNeurons2) + LABEL_HEIGHT;

      TEllipse *ellipse =
         new TEllipse(cx2, cy2[nNeurons2-i-1], effRad2*ratio, effRad2, 0, 360, 0);
      ellipse->SetFillColor(TColor::GetColor( "#fffffd" ));
      ellipse->SetFillStyle(1001);
      ellipse->Draw();

      // a single-neuron layer (output) has no bias node to highlight
      if (i == 0 && nNeurons2 > 1) ellipse->SetLineColor(9);

      if (nNeurons2 > MAX_NEURONS_NICE) continue;

      // output layer and bias node: linear image; hidden neurons: sigmoid
      Int_t whichActivation = 0;
      if (iHist+1 == nLayers-1 || i == 0) whichActivation = 1;
      draw_activation(c, cx2, cy2[nNeurons2-i-1], rad*ratio, rad, whichActivation);
   }

   // normalise weights into [-1, 1]; a non-positive maximum means all weights
   // are zero, so use 1 to avoid 0/0 (NaN would not be skipped by draw_synapse)
   Double_t norm = (maxWeight > 0) ? maxWeight : 1.0;

   for (Int_t i = 0; i < nNeurons1; i++) {
      for (Int_t j = 0; j < nNeurons2; j++) {
         draw_synapse(cx1, cy1[i], cx2, cy2[j], effRad1*ratio, effRad2*ratio,
                      h->GetBinContent(i+1, j+1)/norm);
      }
   }

   delete[] cy1;
   delete[] cy2;
}

// Draw one synapse as an arrow from the right edge of the neuron at
// (cx1, cy1) (radius rad1) to the left edge of the neuron at (cx2, cy2)
// (radius rad2).
//
// weightNormed is the connection weight divided by the network's largest
// |weight|, so it is expected in [-1, 1]:
//   - line width = |weightNormed|*8, rounded to the nearest integer
//   - line colour index interpolates from 60 (blue, -1) to 100 (red, +1)
// A weight of exactly zero draws nothing.
void draw_synapse(Double_t cx1, Double_t cy1, Double_t cx2, Double_t cy2,
                  Double_t  rad1, Double_t rad2, Double_t weightNormed)
{
   if (weightNormed == 0) return;

   const Double_t tipSize    = 0.01;
   const Double_t widthScale = 8;
   const Double_t colorHigh  = 100;  // red
   const Double_t colorLow   = 60;   // blue

   const Double_t fraction  = (weightNormed+1.0)/2.0;
   const Int_t    lineWidth = (Int_t)(TMath::Abs(weightNormed)*widthScale+0.5);
   const Int_t    lineColor = (Int_t)(fraction*(colorHigh-colorLow)+colorLow+0.5);

   const Double_t xStart = cx1 + rad1;
   const Double_t xEnd   = cx2 - rad2;

   TArrow *synapse = new TArrow(xStart, cy1, xEnd, cy2, tipSize, ">");
   synapse->SetFillStyle(1001);
   synapse->SetFillColor(1);
   synapse->SetLineColor(lineColor);
   synapse->SetLineWidth(lineWidth);
   synapse->Draw();
}
