source: trunk/src/STFitter.cpp @ 2662

Last change on this file since 2662 was 2580, checked in by ShinnosukeKawakami, 12 years ago

hpc33 merged asap-trunk

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.3 KB
RevLine 
[91]1//#---------------------------------------------------------------------------
[890]2//# Fitter.cc: A Fitter class for spectra
[91]3//#--------------------------------------------------------------------------
[2444]4//# Copyright (C) 2004-2012
[125]5//# ATNF
[91]6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
[891]29//# $Id: STFitter.cpp 2580 2012-06-28 04:22:10Z ShinnosukeKawakami $
[91]30//#---------------------------------------------------------------------------
[125]31#include <casa/aips.h>
[91]32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
[1819]34#include <casa/Logging/LogIO.h>
[91]35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
[2415]40#include <scimath/Functionals/Lorentzian1D.h>
[2047]41#include <scimath/Functionals/Sinusoid1D.h>
[91]42#include <scimath/Functionals/Polynomial.h>
43#include <scimath/Mathematics/AutoDiff.h>
44#include <scimath/Mathematics/AutoDiffMath.h>
45#include <scimath/Fitting/NonLinearFitLM.h>
46#include <components/SpectralComponents/SpectralEstimate.h>
47
[894]48#include "STFitter.h"
49
[91]50using namespace asap;
[125]51using namespace casa;
[91]52
[890]53Fitter::Fitter()
[91]54{
55}
56
[890]57Fitter::~Fitter()
[91]58{
[517]59  reset();
[91]60}
61
[890]62void Fitter::clear()
[91]63{
[517]64  for (uInt i=0;i< funcs_.nelements();++i) {
65    delete funcs_[i]; funcs_[i] = 0;
66  }
[612]67  funcs_.resize(0,True);
[517]68  parameters_.resize();
[1232]69  fixedpar_.resize();
[517]70  error_.resize();
71  thefit_.resize();
72  estimate_.resize();
73  chisquared_ = 0.0;
[91]74}
[517]75
[890]76void Fitter::reset()
[91]77{
[517]78  clear();
79  x_.resize();
80  y_.resize();
81  m_.resize();
[91]82}
83
84
[890]85bool Fitter::computeEstimate() {
[517]86  if (x_.nelements() == 0 || y_.nelements() == 0)
87    throw (AipsError("No x/y data specified."));
[91]88
[517]89  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
90    return false;
91  uInt n = funcs_.nelements();
92  SpectralEstimate estimator(n);
93  estimator.setQ(5);
94  Int mn,mx;
95  mn = 0;
96  mx = m_.nelements()-1;
97  for (uInt i=0; i<m_.nelements();++i) {
98    if (m_[i]) {
99      mn = i;
100      break;
[108]101    }
[517]102  }
[2163]103  // use Int to suppress compiler warning
104  for (Int j=m_.nelements()-1; j>=0;--j) {
[517]105    if (m_[j]) {
106      mx = j;
107      break;
[108]108    }
[517]109  }
[1067]110  //mn = 0+x_.nelements()/10;
111  //mx = x_.nelements()-x_.nelements()/10;
[517]112  estimator.setRegion(mn,mx);
113  //estimator.setWindowing(True);
114  SpectralList listGauss = estimator.estimate(x_, y_);
115  parameters_.resize(n*3);
116  Gaussian1D<Float>* g = 0;
117  for (uInt i=0; i<n;i++) {
118    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
119    if (g) {
[2445]120      const GaussianSpectralElement *gauss =
121        dynamic_cast<const GaussianSpectralElement *>(listGauss[i]) ;
122      (*g)[0] = gauss->getAmpl();
123      (*g)[1] = gauss->getCenter();
124      (*g)[2] = gauss->getFWHM();     
[2455]125      /*
[2444]126      (*g)[0] = listGauss[i].getAmpl();
127      (*g)[1] = listGauss[i].getCenter();
128      (*g)[2] = listGauss[i].getFWHM();
[2455]129      */
[91]130    }
[517]131  }
132  estimate_.resize();
133  listGauss.evaluate(estimate_,x_);
134  return true;
[91]135}
136
[890]137std::vector<float> Fitter::getEstimate() const
[91]138{
[517]139  if (estimate_.nelements() == 0)
140    throw (AipsError("No estimate set."));
141  std::vector<float> stlout;
142  estimate_.tovector(stlout);
143  return stlout;
[91]144}
145
146
[890]147bool Fitter::setExpression(const std::string& expr, int ncomp)
[91]148{
[517]149  clear();
150  if (expr == "gauss") {
151    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
152    funcs_.resize(ncomp);
[1932]153    funcnames_.clear();
154    funccomponents_.clear();
[517]155    for (Int k=0; k<ncomp; ++k) {
156      funcs_[k] = new Gaussian1D<Float>();
[1932]157      funcnames_.push_back(expr);
158      funccomponents_.push_back(3);
[517]159    }
[1819]160  } else if (expr == "lorentz") {
161    if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
162    funcs_.resize(ncomp);
[1932]163    funcnames_.clear();
164    funccomponents_.clear();
[1819]165    for (Int k=0; k<ncomp; ++k) {
166      funcs_[k] = new Lorentzian1D<Float>();
[1932]167      funcnames_.push_back(expr);
168      funccomponents_.push_back(3);
[1819]169    }
[2047]170  } else if (expr == "sinusoid") {
171    if (ncomp < 1) throw (AipsError("Need at least one sinusoid to fit."));
172    funcs_.resize(ncomp);
173    funcnames_.clear();
174    funccomponents_.clear();
175    for (Int k=0; k<ncomp; ++k) {
176      funcs_[k] = new Sinusoid1D<Float>();
177      funcnames_.push_back(expr);
178      funccomponents_.push_back(3);
179    }
180  } else if (expr == "poly") {
181    funcs_.resize(1);
182    funcnames_.clear();
183    funccomponents_.clear();
184    funcs_[0] = new Polynomial<Float>(ncomp);
185      funcnames_.push_back(expr);
186      funccomponents_.push_back(ncomp);
[517]187  } else {
[1819]188    LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
189    os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
[517]190    //funcs_.resize(1);
191    //funcs_[0] = new CompiledFunction<Float>();
192    //funcs_[0]->setFunction(String(expr));
193    return false;
194  }
195  return true;
[91]196}
197
[890]198bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
[91]199                       std::vector<bool> mask)
200{
201    x_.resize();
202    y_.resize();
203    m_.resize();
204    // convert std::vector to casa Vector
205    Vector<Float> tmpx(absc);
206    Vector<Float> tmpy(spec);
207    Vector<Bool> tmpm(mask);
208    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
209    x_ = tmpx;
210    y_ = tmpy;
211    m_ = tmpm;
212    return true;
213}
214
[890]215std::vector<float> Fitter::getResidual() const
[91]216{
217    if (residual_.nelements() == 0)
218        throw (AipsError("Function not yet fitted."));
219    std::vector<float> stlout;
220    residual_.tovector(stlout);
221    return stlout;
222}
223
[890]224std::vector<float> Fitter::getFit() const
[91]225{
226    Vector<Float> out = thefit_;
227    std::vector<float> stlout;
228    out.tovector(stlout);
229    return stlout;
230
231}
232
[890]233std::vector<float> Fitter::getErrors() const
[91]234{
235    Vector<Float> out = error_;
236    std::vector<float> stlout;
237    out.tovector(stlout);
238    return stlout;
239}
240
[890]241bool Fitter::setParameters(std::vector<float> params)
[91]242{
243    Vector<Float> tmppar(params);
244    if (funcs_.nelements() == 0)
245        throw (AipsError("Function not yet set."));
246    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
247        throw (AipsError("Number of parameters inconsistent with function."));
[1232]248    if (parameters_.nelements() == 0) {
[91]249        parameters_.resize(tmppar.nelements());
[1232]250        if (tmppar.nelements() != fixedpar_.nelements()) {
251            fixedpar_.resize(tmppar.nelements());
252            fixedpar_ = False;
253        }
254    }
[91]255    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
256        uInt count = 0;
257        for (uInt j=0; j < funcs_.nelements(); ++j) {
258            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
259                (funcs_[j]->parameters())[i] = tmppar[count];
260                parameters_[count] = tmppar[count];
261                ++count;
262            }
263        }
[1819]264    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
265        uInt count = 0;
266        for (uInt j=0; j < funcs_.nelements(); ++j) {
267            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
268                (funcs_[j]->parameters())[i] = tmppar[count];
269                parameters_[count] = tmppar[count];
270                ++count;
271            }
272        }
[2047]273    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
274        uInt count = 0;
275        for (uInt j=0; j < funcs_.nelements(); ++j) {
276            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
277                (funcs_[j]->parameters())[i] = tmppar[count];
278                parameters_[count] = tmppar[count];
279                ++count;
280            }
281        }
282    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
283        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
284            parameters_[i] = tmppar[i];
285            (funcs_[0]->parameters())[i] =  tmppar[i];
286        }
[91]287    }
[1232]288    // reset
289    if (params.size() == 0) {
290        parameters_.resize();
291        fixedpar_.resize();
292    }
[91]293    return true;
294}
295
[890]296bool Fitter::setFixedParameters(std::vector<bool> fixed)
[91]297{
298    if (funcs_.nelements() == 0)
299        throw (AipsError("Function not yet set."));
[1232]300    if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
[91]301        throw (AipsError("Number of mask elements inconsistent with function."));
[1232]302    if (fixedpar_.nelements() == 0) {
303        fixedpar_.resize(parameters_.nelements());
304        fixedpar_ = False;
305    }
[91]306    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
307        uInt count = 0;
308        for (uInt j=0; j < funcs_.nelements(); ++j) {
309            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
[1232]310                funcs_[j]->mask(i) = !fixed[count];
311                fixedpar_[count] = fixed[count];
[91]312                ++count;
313            }
314        }
[1819]315    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
316      uInt count = 0;
317        for (uInt j=0; j < funcs_.nelements(); ++j) {
318            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
319                funcs_[j]->mask(i) = !fixed[count];
320                fixedpar_[count] = fixed[count];
321                ++count;
322            }
323        }
[2047]324    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
325      uInt count = 0;
326        for (uInt j=0; j < funcs_.nelements(); ++j) {
327            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
328                funcs_[j]->mask(i) = !fixed[count];
329                fixedpar_[count] = fixed[count];
330                ++count;
331            }
332        }
333    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
334        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
335            fixedpar_[i] = fixed[i];
336            funcs_[0]->mask(i) =  !fixed[i];
337        }
[91]338    }
339    return true;
340}
341
[890]342std::vector<float> Fitter::getParameters() const {
[91]343    Vector<Float> out = parameters_;
344    std::vector<float> stlout;
345    out.tovector(stlout);
346    return stlout;
347}
348
[890]349std::vector<bool> Fitter::getFixedParameters() const {
[108]350  Vector<Bool> out(parameters_.nelements());
351  if (fixedpar_.nelements() == 0) {
[1232]352    return std::vector<bool>();
[108]353    //throw (AipsError("No parameter mask set."));
354  } else {
355    out = fixedpar_;
356  }
357  std::vector<bool> stlout;
358  out.tovector(stlout);
359  return stlout;
[91]360}
361
[890]362float Fitter::getChisquared() const {
[91]363    return chisquared_;
364}
365
[890]366bool Fitter::fit() {
[517]367  NonLinearFitLM<Float> fitter;
368  CompoundFunction<Float> func;
[612]369
370  uInt n = funcs_.nelements();
[517]371  for (uInt i=0; i<n; ++i) {
372    func.addFunction(*funcs_[i]);
373  }
[612]374
[517]375  fitter.setFunction(func);
376  fitter.setMaxIter(50+n*10);
377  // Convergence criterium
378  fitter.setCriteria(0.001);
[612]379
[517]380  // Fit
[2580]381//   Vector<Float> sigma(x_.nelements());
382//   sigma = 1.0;
[890]383
[517]384  parameters_.resize();
[2580]385//   parameters_ = fitter.fit(x_, y_, sigma, &m_);
386  parameters_ = fitter.fit(x_, y_, &m_); 
[1067]387  if ( !fitter.converged() ) {
388     return false;
389  }
[517]390  std::vector<float> ps;
391  parameters_.tovector(ps);
392  setParameters(ps);
[612]393
[517]394  error_.resize();
395  error_ = fitter.errors();
[612]396
[517]397  chisquared_ = fitter.getChi2();
[890]398
[2580]399//   residual_.resize();
400//   residual_ =  y_;
401//   fitter.residual(residual_,x_);
[517]402  // use fitter.residual(model=True) to get the model
403  thefit_.resize(x_.nelements());
404  fitter.residual(thefit_,x_,True);
[2580]405  // residual = data - model
406  residual_.resize(x_.nelements());
407  residual_ = y_ - thefit_ ;
[517]408  return true;
409}
[483]410
[1391]411bool Fitter::lfit() {
412  LinearFit<Float> fitter;
413  CompoundFunction<Float> func;
[483]414
[1391]415  uInt n = funcs_.nelements();
416  for (uInt i=0; i<n; ++i) {
417    func.addFunction(*funcs_[i]);
418  }
419
420  fitter.setFunction(func);
421  //fitter.setMaxIter(50+n*10);
422  // Convergence criterium
423  //fitter.setCriteria(0.001);
424
425  // Fit
[2580]426//   Vector<Float> sigma(x_.nelements());
427//   sigma = 1.0;
[1391]428
429  parameters_.resize();
[2580]430//   parameters_ = fitter.fit(x_, y_, sigma, &m_);
431  parameters_ = fitter.fit(x_, y_, &m_);
[1391]432  std::vector<float> ps;
433  parameters_.tovector(ps);
434  setParameters(ps);
435
436  error_.resize();
437  error_ = fitter.errors();
438
439  chisquared_ = fitter.getChi2();
440
[2580]441//   residual_.resize();
442//   residual_ =  y_;
443//   fitter.residual(residual_,x_);
[1391]444  // use fitter.residual(model=True) to get the model
445  thefit_.resize(x_.nelements());
446  fitter.residual(thefit_,x_,True);
[2580]447  // residual = data - model
448  residual_.resize(x_.nelements());
449  residual_ = y_ - thefit_ ;
[1391]450  return true;
451}
452
[890]453std::vector<float> Fitter::evaluate(int whichComp) const
454{
[517]455  std::vector<float> stlout;
[890]456  uInt idx = uInt(whichComp);
[517]457  Float y;
458  if ( idx < funcs_.nelements() ) {
459    for (uInt i=0; i<x_.nelements(); ++i) {
460      y = (*funcs_[idx])(x_[i]);
461      stlout.push_back(float(y));
462    }
463  }
464  return stlout;
465}
[483]466
[1932]467STFitEntry Fitter::getFitEntry() const
468{
469  STFitEntry fit;
470  fit.setParameters(getParameters());
471  fit.setErrors(getErrors());
472  fit.setComponents(funccomponents_);
473  fit.setFunctions(funcnames_);
474  fit.setParmasks(getFixedParameters());
475  return fit;
476}
Note: See TracBrowser for help on using the repository browser.