source: branches/mergetest/src/STFitter.cpp @ 1779

Last change on this file since 1779 was 1779, checked in by Kana Sugimoto, 14 years ago

New Development: Yes

JIRA Issue: No (test merging alma branch)

Ready for Test: Yes

Interface Changes: Yes

What Interface Changed:

Test Programs:

Put in Release Notes: No

Module(s):

Description:


  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.1 KB
RevLine 
[91]1//#---------------------------------------------------------------------------
[890]2//# Fitter.cc: A Fitter class for spectra
[91]3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004
[125]5//# ATNF
[91]6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
[891]29//# $Id: STFitter.cpp 1779 2010-07-29 09:13:46Z KanaSugimoto $
[91]30//#---------------------------------------------------------------------------
[125]31#include <casa/aips.h>
[91]32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
[1779]34#include <casa/Logging/LogIO.h>
[91]35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
[1779]40#include "Lorentzian1D.h"
[91]41#include <scimath/Functionals/Polynomial.h>
42#include <scimath/Mathematics/AutoDiff.h>
43#include <scimath/Mathematics/AutoDiffMath.h>
44#include <scimath/Fitting/NonLinearFitLM.h>
45#include <components/SpectralComponents/SpectralEstimate.h>
46
[894]47#include "STFitter.h"
48
[91]49using namespace asap;
[125]50using namespace casa;
[91]51
[890]52Fitter::Fitter()
[91]53{
54}
55
[890]56Fitter::~Fitter()
[91]57{
[517]58  reset();
[91]59}
60
[890]61void Fitter::clear()
[91]62{
[517]63  for (uInt i=0;i< funcs_.nelements();++i) {
64    delete funcs_[i]; funcs_[i] = 0;
65  }
[612]66  funcs_.resize(0,True);
[517]67  parameters_.resize();
[1232]68  fixedpar_.resize();
[517]69  error_.resize();
70  thefit_.resize();
71  estimate_.resize();
72  chisquared_ = 0.0;
[91]73}
[517]74
[890]75void Fitter::reset()
[91]76{
[517]77  clear();
78  x_.resize();
79  y_.resize();
80  m_.resize();
[91]81}
82
83
[890]84bool Fitter::computeEstimate() {
[517]85  if (x_.nelements() == 0 || y_.nelements() == 0)
86    throw (AipsError("No x/y data specified."));
[91]87
[517]88  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
89    return false;
90  uInt n = funcs_.nelements();
91  SpectralEstimate estimator(n);
92  estimator.setQ(5);
93  Int mn,mx;
94  mn = 0;
95  mx = m_.nelements()-1;
96  for (uInt i=0; i<m_.nelements();++i) {
97    if (m_[i]) {
98      mn = i;
99      break;
[108]100    }
[517]101  }
102  for (uInt j=m_.nelements()-1; j>=0;--j) {
103    if (m_[j]) {
104      mx = j;
105      break;
[108]106    }
[517]107  }
[1067]108  //mn = 0+x_.nelements()/10;
109  //mx = x_.nelements()-x_.nelements()/10;
[517]110  estimator.setRegion(mn,mx);
111  //estimator.setWindowing(True);
112  SpectralList listGauss = estimator.estimate(x_, y_);
113  parameters_.resize(n*3);
114  Gaussian1D<Float>* g = 0;
115  for (uInt i=0; i<n;i++) {
116    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
117    if (g) {
118      (*g)[0] = listGauss[i].getAmpl();
119      (*g)[1] = listGauss[i].getCenter();
120      (*g)[2] = listGauss[i].getFWHM();
[91]121    }
[517]122  }
123  estimate_.resize();
124  listGauss.evaluate(estimate_,x_);
125  return true;
[91]126}
127
[890]128std::vector<float> Fitter::getEstimate() const
[91]129{
[517]130  if (estimate_.nelements() == 0)
131    throw (AipsError("No estimate set."));
132  std::vector<float> stlout;
133  estimate_.tovector(stlout);
134  return stlout;
[91]135}
136
137
[890]138bool Fitter::setExpression(const std::string& expr, int ncomp)
[91]139{
[517]140  clear();
141  if (expr == "gauss") {
142    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
143    funcs_.resize(ncomp);
144    for (Int k=0; k<ncomp; ++k) {
145      funcs_[k] = new Gaussian1D<Float>();
146    }
147  } else if (expr == "poly") {
148    funcs_.resize(1);
149    funcs_[0] = new Polynomial<Float>(ncomp);
[1779]150  } else if (expr == "lorentz") {
151    if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
152    funcs_.resize(ncomp);
153    for (Int k=0; k<ncomp; ++k) {
154      funcs_[k] = new Lorentzian1D<Float>();
155    }
[517]156  } else {
[1779]157    //cerr << " compiled functions not yet implemented" << endl;
158    LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
159    os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
[517]160    //funcs_.resize(1);
161    //funcs_[0] = new CompiledFunction<Float>();
162    //funcs_[0]->setFunction(String(expr));
163    return false;
164  }
165  return true;
[91]166}
167
[890]168bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
[91]169                       std::vector<bool> mask)
170{
171    x_.resize();
172    y_.resize();
173    m_.resize();
174    // convert std::vector to casa Vector
175    Vector<Float> tmpx(absc);
176    Vector<Float> tmpy(spec);
177    Vector<Bool> tmpm(mask);
178    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
179    x_ = tmpx;
180    y_ = tmpy;
181    m_ = tmpm;
182    return true;
183}
184
[890]185std::vector<float> Fitter::getResidual() const
[91]186{
187    if (residual_.nelements() == 0)
188        throw (AipsError("Function not yet fitted."));
189    std::vector<float> stlout;
190    residual_.tovector(stlout);
191    return stlout;
192}
193
[890]194std::vector<float> Fitter::getFit() const
[91]195{
196    Vector<Float> out = thefit_;
197    std::vector<float> stlout;
198    out.tovector(stlout);
199    return stlout;
200
201}
202
[890]203std::vector<float> Fitter::getErrors() const
[91]204{
205    Vector<Float> out = error_;
206    std::vector<float> stlout;
207    out.tovector(stlout);
208    return stlout;
209}
210
[890]211bool Fitter::setParameters(std::vector<float> params)
[91]212{
213    Vector<Float> tmppar(params);
214    if (funcs_.nelements() == 0)
215        throw (AipsError("Function not yet set."));
216    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
217        throw (AipsError("Number of parameters inconsistent with function."));
[1232]218    if (parameters_.nelements() == 0) {
[91]219        parameters_.resize(tmppar.nelements());
[1232]220        if (tmppar.nelements() != fixedpar_.nelements()) {
221            fixedpar_.resize(tmppar.nelements());
222            fixedpar_ = False;
223        }
224    }
[91]225    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
226        uInt count = 0;
227        for (uInt j=0; j < funcs_.nelements(); ++j) {
228            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
229                (funcs_[j]->parameters())[i] = tmppar[count];
230                parameters_[count] = tmppar[count];
231                ++count;
232            }
233        }
234    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
235        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
236            parameters_[i] = tmppar[i];
237            (funcs_[0]->parameters())[i] =  tmppar[i];
238        }
[1779]239    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
240        uInt count = 0;
241        for (uInt j=0; j < funcs_.nelements(); ++j) {
242            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
243                (funcs_[j]->parameters())[i] = tmppar[count];
244                parameters_[count] = tmppar[count];
245                ++count;
246            }
247        }
[91]248    }
[1232]249    // reset
250    if (params.size() == 0) {
251        parameters_.resize();
252        fixedpar_.resize();
253    }
[91]254    return true;
255}
256
[890]257bool Fitter::setFixedParameters(std::vector<bool> fixed)
[91]258{
259    if (funcs_.nelements() == 0)
260        throw (AipsError("Function not yet set."));
[1232]261    if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
[91]262        throw (AipsError("Number of mask elements inconsistent with function."));
[1232]263    if (fixedpar_.nelements() == 0) {
264        fixedpar_.resize(parameters_.nelements());
265        fixedpar_ = False;
266    }
[91]267    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
268        uInt count = 0;
269        for (uInt j=0; j < funcs_.nelements(); ++j) {
270            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
[1232]271                funcs_[j]->mask(i) = !fixed[count];
272                fixedpar_[count] = fixed[count];
[91]273                ++count;
274            }
275        }
276    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
277        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
[1232]278            fixedpar_[i] = fixed[i];
279            funcs_[0]->mask(i) =  !fixed[i];
[91]280        }
[1779]281    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
282      uInt count = 0;
283        for (uInt j=0; j < funcs_.nelements(); ++j) {
284            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
285                funcs_[j]->mask(i) = !fixed[count];
286                fixedpar_[count] = fixed[count];
287                ++count;
288            }
289        }
[91]290    }
291    return true;
292}
293
[890]294std::vector<float> Fitter::getParameters() const {
[91]295    Vector<Float> out = parameters_;
296    std::vector<float> stlout;
297    out.tovector(stlout);
298    return stlout;
299}
300
[890]301std::vector<bool> Fitter::getFixedParameters() const {
[108]302  Vector<Bool> out(parameters_.nelements());
303  if (fixedpar_.nelements() == 0) {
[1232]304    return std::vector<bool>();
[108]305    //throw (AipsError("No parameter mask set."));
306  } else {
307    out = fixedpar_;
308  }
309  std::vector<bool> stlout;
310  out.tovector(stlout);
311  return stlout;
[91]312}
313
[890]314float Fitter::getChisquared() const {
[91]315    return chisquared_;
316}
317
[890]318bool Fitter::fit() {
[517]319  NonLinearFitLM<Float> fitter;
320  CompoundFunction<Float> func;
[612]321
322  uInt n = funcs_.nelements();
[517]323  for (uInt i=0; i<n; ++i) {
324    func.addFunction(*funcs_[i]);
325  }
[612]326
[517]327  fitter.setFunction(func);
328  fitter.setMaxIter(50+n*10);
329  // Convergence criterium
330  fitter.setCriteria(0.001);
[612]331
[517]332  // Fit
333  Vector<Float> sigma(x_.nelements());
334  sigma = 1.0;
[890]335
[517]336  parameters_.resize();
337  parameters_ = fitter.fit(x_, y_, sigma, &m_);
[1067]338  if ( !fitter.converged() ) {
339     return false;
340  }
[517]341  std::vector<float> ps;
342  parameters_.tovector(ps);
343  setParameters(ps);
[612]344
[517]345  error_.resize();
346  error_ = fitter.errors();
[612]347
[517]348  chisquared_ = fitter.getChi2();
[890]349
[517]350  residual_.resize();
351  residual_ =  y_;
352  fitter.residual(residual_,x_);
353  // use fitter.residual(model=True) to get the model
354  thefit_.resize(x_.nelements());
355  fitter.residual(thefit_,x_,True);
356  return true;
357}
[483]358
[1391]359bool Fitter::lfit() {
360  LinearFit<Float> fitter;
361  CompoundFunction<Float> func;
[483]362
[1391]363  uInt n = funcs_.nelements();
364  for (uInt i=0; i<n; ++i) {
365    func.addFunction(*funcs_[i]);
366  }
367
368  fitter.setFunction(func);
369  //fitter.setMaxIter(50+n*10);
370  // Convergence criterium
371  //fitter.setCriteria(0.001);
372
373  // Fit
374  Vector<Float> sigma(x_.nelements());
375  sigma = 1.0;
376
377  parameters_.resize();
378  parameters_ = fitter.fit(x_, y_, sigma, &m_);
379  std::vector<float> ps;
380  parameters_.tovector(ps);
381  setParameters(ps);
382
383  error_.resize();
384  error_ = fitter.errors();
385
386  chisquared_ = fitter.getChi2();
387
388  residual_.resize();
389  residual_ =  y_;
390  fitter.residual(residual_,x_);
391  // use fitter.residual(model=True) to get the model
392  thefit_.resize(x_.nelements());
393  fitter.residual(thefit_,x_,True);
394  return true;
395}
396
[890]397std::vector<float> Fitter::evaluate(int whichComp) const
398{
[517]399  std::vector<float> stlout;
[890]400  uInt idx = uInt(whichComp);
[517]401  Float y;
402  if ( idx < funcs_.nelements() ) {
403    for (uInt i=0; i<x_.nelements(); ++i) {
404      y = (*funcs_[idx])(x_[i]);
405      stlout.push_back(float(y));
406    }
407  }
408  return stlout;
409}
[483]410
Note: See TracBrowser for help on using the repository browser.