source: trunk/src/STFitter.cpp @ 2580

Last change on this file since 2580 was 2580, checked in by ShinnosukeKawakami, 12 years ago

hpc33 merged asap-trunk

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.3 KB
Line 
1//#---------------------------------------------------------------------------
2//# STFitter.cpp: A Fitter class for spectra
3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004-2012
5//# ATNF
6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
29//# $Id: STFitter.cpp 2580 2012-06-28 04:22:10Z ShinnosukeKawakami $
30//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <casa/Logging/LogIO.h>
35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
40#include <scimath/Functionals/Lorentzian1D.h>
41#include <scimath/Functionals/Sinusoid1D.h>
42#include <scimath/Functionals/Polynomial.h>
43#include <scimath/Mathematics/AutoDiff.h>
44#include <scimath/Mathematics/AutoDiffMath.h>
45#include <scimath/Fitting/NonLinearFitLM.h>
46#include <components/SpectralComponents/SpectralEstimate.h>
47
48#include "STFitter.h"
49
50using namespace asap;
51using namespace casa;
52
53Fitter::Fitter()
54{
55}
56
57Fitter::~Fitter()
58{
59  reset();
60}
61
62void Fitter::clear()
63{
64  for (uInt i=0;i< funcs_.nelements();++i) {
65    delete funcs_[i]; funcs_[i] = 0;
66  }
67  funcs_.resize(0,True);
68  parameters_.resize();
69  fixedpar_.resize();
70  error_.resize();
71  thefit_.resize();
72  estimate_.resize();
73  chisquared_ = 0.0;
74}
75
76void Fitter::reset()
77{
78  clear();
79  x_.resize();
80  y_.resize();
81  m_.resize();
82}
83
84
85bool Fitter::computeEstimate() {
86  if (x_.nelements() == 0 || y_.nelements() == 0)
87    throw (AipsError("No x/y data specified."));
88
89  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
90    return false;
91  uInt n = funcs_.nelements();
92  SpectralEstimate estimator(n);
93  estimator.setQ(5);
94  Int mn,mx;
95  mn = 0;
96  mx = m_.nelements()-1;
97  for (uInt i=0; i<m_.nelements();++i) {
98    if (m_[i]) {
99      mn = i;
100      break;
101    }
102  }
103  // use Int to suppress compiler warning
104  for (Int j=m_.nelements()-1; j>=0;--j) {
105    if (m_[j]) {
106      mx = j;
107      break;
108    }
109  }
110  //mn = 0+x_.nelements()/10;
111  //mx = x_.nelements()-x_.nelements()/10;
112  estimator.setRegion(mn,mx);
113  //estimator.setWindowing(True);
114  SpectralList listGauss = estimator.estimate(x_, y_);
115  parameters_.resize(n*3);
116  Gaussian1D<Float>* g = 0;
117  for (uInt i=0; i<n;i++) {
118    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
119    if (g) {
120      const GaussianSpectralElement *gauss =
121        dynamic_cast<const GaussianSpectralElement *>(listGauss[i]) ;
122      (*g)[0] = gauss->getAmpl();
123      (*g)[1] = gauss->getCenter();
124      (*g)[2] = gauss->getFWHM();     
125      /*
126      (*g)[0] = listGauss[i].getAmpl();
127      (*g)[1] = listGauss[i].getCenter();
128      (*g)[2] = listGauss[i].getFWHM();
129      */
130    }
131  }
132  estimate_.resize();
133  listGauss.evaluate(estimate_,x_);
134  return true;
135}
136
137std::vector<float> Fitter::getEstimate() const
138{
139  if (estimate_.nelements() == 0)
140    throw (AipsError("No estimate set."));
141  std::vector<float> stlout;
142  estimate_.tovector(stlout);
143  return stlout;
144}
145
146
147bool Fitter::setExpression(const std::string& expr, int ncomp)
148{
149  clear();
150  if (expr == "gauss") {
151    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
152    funcs_.resize(ncomp);
153    funcnames_.clear();
154    funccomponents_.clear();
155    for (Int k=0; k<ncomp; ++k) {
156      funcs_[k] = new Gaussian1D<Float>();
157      funcnames_.push_back(expr);
158      funccomponents_.push_back(3);
159    }
160  } else if (expr == "lorentz") {
161    if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
162    funcs_.resize(ncomp);
163    funcnames_.clear();
164    funccomponents_.clear();
165    for (Int k=0; k<ncomp; ++k) {
166      funcs_[k] = new Lorentzian1D<Float>();
167      funcnames_.push_back(expr);
168      funccomponents_.push_back(3);
169    }
170  } else if (expr == "sinusoid") {
171    if (ncomp < 1) throw (AipsError("Need at least one sinusoid to fit."));
172    funcs_.resize(ncomp);
173    funcnames_.clear();
174    funccomponents_.clear();
175    for (Int k=0; k<ncomp; ++k) {
176      funcs_[k] = new Sinusoid1D<Float>();
177      funcnames_.push_back(expr);
178      funccomponents_.push_back(3);
179    }
180  } else if (expr == "poly") {
181    funcs_.resize(1);
182    funcnames_.clear();
183    funccomponents_.clear();
184    funcs_[0] = new Polynomial<Float>(ncomp);
185      funcnames_.push_back(expr);
186      funccomponents_.push_back(ncomp);
187  } else {
188    LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
189    os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
190    //funcs_.resize(1);
191    //funcs_[0] = new CompiledFunction<Float>();
192    //funcs_[0]->setFunction(String(expr));
193    return false;
194  }
195  return true;
196}
197
198bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
199                       std::vector<bool> mask)
200{
201    x_.resize();
202    y_.resize();
203    m_.resize();
204    // convert std::vector to casa Vector
205    Vector<Float> tmpx(absc);
206    Vector<Float> tmpy(spec);
207    Vector<Bool> tmpm(mask);
208    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
209    x_ = tmpx;
210    y_ = tmpy;
211    m_ = tmpm;
212    return true;
213}
214
215std::vector<float> Fitter::getResidual() const
216{
217    if (residual_.nelements() == 0)
218        throw (AipsError("Function not yet fitted."));
219    std::vector<float> stlout;
220    residual_.tovector(stlout);
221    return stlout;
222}
223
224std::vector<float> Fitter::getFit() const
225{
226    Vector<Float> out = thefit_;
227    std::vector<float> stlout;
228    out.tovector(stlout);
229    return stlout;
230
231}
232
233std::vector<float> Fitter::getErrors() const
234{
235    Vector<Float> out = error_;
236    std::vector<float> stlout;
237    out.tovector(stlout);
238    return stlout;
239}
240
241bool Fitter::setParameters(std::vector<float> params)
242{
243    Vector<Float> tmppar(params);
244    if (funcs_.nelements() == 0)
245        throw (AipsError("Function not yet set."));
246    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
247        throw (AipsError("Number of parameters inconsistent with function."));
248    if (parameters_.nelements() == 0) {
249        parameters_.resize(tmppar.nelements());
250        if (tmppar.nelements() != fixedpar_.nelements()) {
251            fixedpar_.resize(tmppar.nelements());
252            fixedpar_ = False;
253        }
254    }
255    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
256        uInt count = 0;
257        for (uInt j=0; j < funcs_.nelements(); ++j) {
258            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
259                (funcs_[j]->parameters())[i] = tmppar[count];
260                parameters_[count] = tmppar[count];
261                ++count;
262            }
263        }
264    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
265        uInt count = 0;
266        for (uInt j=0; j < funcs_.nelements(); ++j) {
267            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
268                (funcs_[j]->parameters())[i] = tmppar[count];
269                parameters_[count] = tmppar[count];
270                ++count;
271            }
272        }
273    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
274        uInt count = 0;
275        for (uInt j=0; j < funcs_.nelements(); ++j) {
276            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
277                (funcs_[j]->parameters())[i] = tmppar[count];
278                parameters_[count] = tmppar[count];
279                ++count;
280            }
281        }
282    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
283        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
284            parameters_[i] = tmppar[i];
285            (funcs_[0]->parameters())[i] =  tmppar[i];
286        }
287    }
288    // reset
289    if (params.size() == 0) {
290        parameters_.resize();
291        fixedpar_.resize();
292    }
293    return true;
294}
295
296bool Fitter::setFixedParameters(std::vector<bool> fixed)
297{
298    if (funcs_.nelements() == 0)
299        throw (AipsError("Function not yet set."));
300    if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
301        throw (AipsError("Number of mask elements inconsistent with function."));
302    if (fixedpar_.nelements() == 0) {
303        fixedpar_.resize(parameters_.nelements());
304        fixedpar_ = False;
305    }
306    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
307        uInt count = 0;
308        for (uInt j=0; j < funcs_.nelements(); ++j) {
309            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
310                funcs_[j]->mask(i) = !fixed[count];
311                fixedpar_[count] = fixed[count];
312                ++count;
313            }
314        }
315    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
316      uInt count = 0;
317        for (uInt j=0; j < funcs_.nelements(); ++j) {
318            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
319                funcs_[j]->mask(i) = !fixed[count];
320                fixedpar_[count] = fixed[count];
321                ++count;
322            }
323        }
324    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
325      uInt count = 0;
326        for (uInt j=0; j < funcs_.nelements(); ++j) {
327            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
328                funcs_[j]->mask(i) = !fixed[count];
329                fixedpar_[count] = fixed[count];
330                ++count;
331            }
332        }
333    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
334        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
335            fixedpar_[i] = fixed[i];
336            funcs_[0]->mask(i) =  !fixed[i];
337        }
338    }
339    return true;
340}
341
342std::vector<float> Fitter::getParameters() const {
343    Vector<Float> out = parameters_;
344    std::vector<float> stlout;
345    out.tovector(stlout);
346    return stlout;
347}
348
349std::vector<bool> Fitter::getFixedParameters() const {
350  Vector<Bool> out(parameters_.nelements());
351  if (fixedpar_.nelements() == 0) {
352    return std::vector<bool>();
353    //throw (AipsError("No parameter mask set."));
354  } else {
355    out = fixedpar_;
356  }
357  std::vector<bool> stlout;
358  out.tovector(stlout);
359  return stlout;
360}
361
362float Fitter::getChisquared() const {
363    return chisquared_;
364}
365
366bool Fitter::fit() {
367  NonLinearFitLM<Float> fitter;
368  CompoundFunction<Float> func;
369
370  uInt n = funcs_.nelements();
371  for (uInt i=0; i<n; ++i) {
372    func.addFunction(*funcs_[i]);
373  }
374
375  fitter.setFunction(func);
376  fitter.setMaxIter(50+n*10);
377  // Convergence criterium
378  fitter.setCriteria(0.001);
379
380  // Fit
381//   Vector<Float> sigma(x_.nelements());
382//   sigma = 1.0;
383
384  parameters_.resize();
385//   parameters_ = fitter.fit(x_, y_, sigma, &m_);
386  parameters_ = fitter.fit(x_, y_, &m_); 
387  if ( !fitter.converged() ) {
388     return false;
389  }
390  std::vector<float> ps;
391  parameters_.tovector(ps);
392  setParameters(ps);
393
394  error_.resize();
395  error_ = fitter.errors();
396
397  chisquared_ = fitter.getChi2();
398
399//   residual_.resize();
400//   residual_ =  y_;
401//   fitter.residual(residual_,x_);
402  // use fitter.residual(model=True) to get the model
403  thefit_.resize(x_.nelements());
404  fitter.residual(thefit_,x_,True);
405  // residual = data - model
406  residual_.resize(x_.nelements());
407  residual_ = y_ - thefit_ ;
408  return true;
409}
410
411bool Fitter::lfit() {
412  LinearFit<Float> fitter;
413  CompoundFunction<Float> func;
414
415  uInt n = funcs_.nelements();
416  for (uInt i=0; i<n; ++i) {
417    func.addFunction(*funcs_[i]);
418  }
419
420  fitter.setFunction(func);
421  //fitter.setMaxIter(50+n*10);
422  // Convergence criterium
423  //fitter.setCriteria(0.001);
424
425  // Fit
426//   Vector<Float> sigma(x_.nelements());
427//   sigma = 1.0;
428
429  parameters_.resize();
430//   parameters_ = fitter.fit(x_, y_, sigma, &m_);
431  parameters_ = fitter.fit(x_, y_, &m_);
432  std::vector<float> ps;
433  parameters_.tovector(ps);
434  setParameters(ps);
435
436  error_.resize();
437  error_ = fitter.errors();
438
439  chisquared_ = fitter.getChi2();
440
441//   residual_.resize();
442//   residual_ =  y_;
443//   fitter.residual(residual_,x_);
444  // use fitter.residual(model=True) to get the model
445  thefit_.resize(x_.nelements());
446  fitter.residual(thefit_,x_,True);
447  // residual = data - model
448  residual_.resize(x_.nelements());
449  residual_ = y_ - thefit_ ;
450  return true;
451}
452
453std::vector<float> Fitter::evaluate(int whichComp) const
454{
455  std::vector<float> stlout;
456  uInt idx = uInt(whichComp);
457  Float y;
458  if ( idx < funcs_.nelements() ) {
459    for (uInt i=0; i<x_.nelements(); ++i) {
460      y = (*funcs_[idx])(x_[i]);
461      stlout.push_back(float(y));
462    }
463  }
464  return stlout;
465}
466
467STFitEntry Fitter::getFitEntry() const
468{
469  STFitEntry fit;
470  fit.setParameters(getParameters());
471  fit.setErrors(getErrors());
472  fit.setComponents(funccomponents_);
473  fit.setFunctions(funcnames_);
474  fit.setParmasks(getFixedParameters());
475  return fit;
476}
Note: See TracBrowser for help on using the repository browser.