source: branches/casa-prerelease/pre-asap/src/STFitter.cpp @ 2342

Last change on this file since 2342 was 2163, checked in by Malte Marquarding, 13 years ago

Remove various compiler warnings

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 12.8 KB
Line 
1//#---------------------------------------------------------------------------
2//# Fitter.cc: A Fitter class for spectra
3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004
5//# ATNF
6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
29//# $Id: STFitter.cpp 2163 2011-05-10 05:02:56Z MalteMarquarding $
30//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <casa/Logging/LogIO.h>
35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
40#include "Lorentzian1D.h"
41#include <scimath/Functionals/Sinusoid1D.h>
42#include <scimath/Functionals/Polynomial.h>
43#include <scimath/Mathematics/AutoDiff.h>
44#include <scimath/Mathematics/AutoDiffMath.h>
45#include <scimath/Fitting/NonLinearFitLM.h>
46#include <components/SpectralComponents/SpectralEstimate.h>
47
48#include "STFitter.h"
49
50using namespace asap;
51using namespace casa;
52
53Fitter::Fitter()
54{
55}
56
57Fitter::~Fitter()
58{
59  reset();
60}
61
62void Fitter::clear()
63{
64  for (uInt i=0;i< funcs_.nelements();++i) {
65    delete funcs_[i]; funcs_[i] = 0;
66  }
67  funcs_.resize(0,True);
68  parameters_.resize();
69  fixedpar_.resize();
70  error_.resize();
71  thefit_.resize();
72  estimate_.resize();
73  chisquared_ = 0.0;
74}
75
76void Fitter::reset()
77{
78  clear();
79  x_.resize();
80  y_.resize();
81  m_.resize();
82}
83
84
85bool Fitter::computeEstimate() {
86  if (x_.nelements() == 0 || y_.nelements() == 0)
87    throw (AipsError("No x/y data specified."));
88
89  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
90    return false;
91  uInt n = funcs_.nelements();
92  SpectralEstimate estimator(n);
93  estimator.setQ(5);
94  Int mn,mx;
95  mn = 0;
96  mx = m_.nelements()-1;
97  for (uInt i=0; i<m_.nelements();++i) {
98    if (m_[i]) {
99      mn = i;
100      break;
101    }
102  }
103  // use Int to suppress compiler warning
104  for (Int j=m_.nelements()-1; j>=0;--j) {
105    if (m_[j]) {
106      mx = j;
107      break;
108    }
109  }
110  //mn = 0+x_.nelements()/10;
111  //mx = x_.nelements()-x_.nelements()/10;
112  estimator.setRegion(mn,mx);
113  //estimator.setWindowing(True);
114  SpectralList listGauss = estimator.estimate(x_, y_);
115  parameters_.resize(n*3);
116  Gaussian1D<Float>* g = 0;
117  for (uInt i=0; i<n;i++) {
118    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
119    if (g) {
120      (*g)[0] = listGauss[i].getAmpl();
121      (*g)[1] = listGauss[i].getCenter();
122      (*g)[2] = listGauss[i].getFWHM();
123    }
124  }
125  estimate_.resize();
126  listGauss.evaluate(estimate_,x_);
127  return true;
128}
129
130std::vector<float> Fitter::getEstimate() const
131{
132  if (estimate_.nelements() == 0)
133    throw (AipsError("No estimate set."));
134  std::vector<float> stlout;
135  estimate_.tovector(stlout);
136  return stlout;
137}
138
139
140bool Fitter::setExpression(const std::string& expr, int ncomp)
141{
142  clear();
143  if (expr == "gauss") {
144    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
145    funcs_.resize(ncomp);
146    funcnames_.clear();
147    funccomponents_.clear();
148    for (Int k=0; k<ncomp; ++k) {
149      funcs_[k] = new Gaussian1D<Float>();
150      funcnames_.push_back(expr);
151      funccomponents_.push_back(3);
152    }
153  } else if (expr == "lorentz") {
154    if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
155    funcs_.resize(ncomp);
156    funcnames_.clear();
157    funccomponents_.clear();
158    for (Int k=0; k<ncomp; ++k) {
159      funcs_[k] = new Lorentzian1D<Float>();
160      funcnames_.push_back(expr);
161      funccomponents_.push_back(3);
162    }
163  } else if (expr == "sinusoid") {
164    if (ncomp < 1) throw (AipsError("Need at least one sinusoid to fit."));
165    funcs_.resize(ncomp);
166    funcnames_.clear();
167    funccomponents_.clear();
168    for (Int k=0; k<ncomp; ++k) {
169      funcs_[k] = new Sinusoid1D<Float>();
170      funcnames_.push_back(expr);
171      funccomponents_.push_back(3);
172    }
173  } else if (expr == "poly") {
174    funcs_.resize(1);
175    funcnames_.clear();
176    funccomponents_.clear();
177    funcs_[0] = new Polynomial<Float>(ncomp);
178      funcnames_.push_back(expr);
179      funccomponents_.push_back(ncomp);
180  } else {
181    LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
182    os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
183    //funcs_.resize(1);
184    //funcs_[0] = new CompiledFunction<Float>();
185    //funcs_[0]->setFunction(String(expr));
186    return false;
187  }
188  return true;
189}
190
191bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
192                       std::vector<bool> mask)
193{
194    x_.resize();
195    y_.resize();
196    m_.resize();
197    // convert std::vector to casa Vector
198    Vector<Float> tmpx(absc);
199    Vector<Float> tmpy(spec);
200    Vector<Bool> tmpm(mask);
201    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
202    x_ = tmpx;
203    y_ = tmpy;
204    m_ = tmpm;
205    return true;
206}
207
208std::vector<float> Fitter::getResidual() const
209{
210    if (residual_.nelements() == 0)
211        throw (AipsError("Function not yet fitted."));
212    std::vector<float> stlout;
213    residual_.tovector(stlout);
214    return stlout;
215}
216
217std::vector<float> Fitter::getFit() const
218{
219    Vector<Float> out = thefit_;
220    std::vector<float> stlout;
221    out.tovector(stlout);
222    return stlout;
223
224}
225
226std::vector<float> Fitter::getErrors() const
227{
228    Vector<Float> out = error_;
229    std::vector<float> stlout;
230    out.tovector(stlout);
231    return stlout;
232}
233
234bool Fitter::setParameters(std::vector<float> params)
235{
236    Vector<Float> tmppar(params);
237    if (funcs_.nelements() == 0)
238        throw (AipsError("Function not yet set."));
239    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
240        throw (AipsError("Number of parameters inconsistent with function."));
241    if (parameters_.nelements() == 0) {
242        parameters_.resize(tmppar.nelements());
243        if (tmppar.nelements() != fixedpar_.nelements()) {
244            fixedpar_.resize(tmppar.nelements());
245            fixedpar_ = False;
246        }
247    }
248    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
249        uInt count = 0;
250        for (uInt j=0; j < funcs_.nelements(); ++j) {
251            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
252                (funcs_[j]->parameters())[i] = tmppar[count];
253                parameters_[count] = tmppar[count];
254                ++count;
255            }
256        }
257    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
258        uInt count = 0;
259        for (uInt j=0; j < funcs_.nelements(); ++j) {
260            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
261                (funcs_[j]->parameters())[i] = tmppar[count];
262                parameters_[count] = tmppar[count];
263                ++count;
264            }
265        }
266    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
267        uInt count = 0;
268        for (uInt j=0; j < funcs_.nelements(); ++j) {
269            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
270                (funcs_[j]->parameters())[i] = tmppar[count];
271                parameters_[count] = tmppar[count];
272                ++count;
273            }
274        }
275    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
276        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
277            parameters_[i] = tmppar[i];
278            (funcs_[0]->parameters())[i] =  tmppar[i];
279        }
280    }
281    // reset
282    if (params.size() == 0) {
283        parameters_.resize();
284        fixedpar_.resize();
285    }
286    return true;
287}
288
289bool Fitter::setFixedParameters(std::vector<bool> fixed)
290{
291    if (funcs_.nelements() == 0)
292        throw (AipsError("Function not yet set."));
293    if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
294        throw (AipsError("Number of mask elements inconsistent with function."));
295    if (fixedpar_.nelements() == 0) {
296        fixedpar_.resize(parameters_.nelements());
297        fixedpar_ = False;
298    }
299    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
300        uInt count = 0;
301        for (uInt j=0; j < funcs_.nelements(); ++j) {
302            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
303                funcs_[j]->mask(i) = !fixed[count];
304                fixedpar_[count] = fixed[count];
305                ++count;
306            }
307        }
308    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
309      uInt count = 0;
310        for (uInt j=0; j < funcs_.nelements(); ++j) {
311            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
312                funcs_[j]->mask(i) = !fixed[count];
313                fixedpar_[count] = fixed[count];
314                ++count;
315            }
316        }
317    } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
318      uInt count = 0;
319        for (uInt j=0; j < funcs_.nelements(); ++j) {
320            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
321                funcs_[j]->mask(i) = !fixed[count];
322                fixedpar_[count] = fixed[count];
323                ++count;
324            }
325        }
326    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
327        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
328            fixedpar_[i] = fixed[i];
329            funcs_[0]->mask(i) =  !fixed[i];
330        }
331    }
332    return true;
333}
334
335std::vector<float> Fitter::getParameters() const {
336    Vector<Float> out = parameters_;
337    std::vector<float> stlout;
338    out.tovector(stlout);
339    return stlout;
340}
341
342std::vector<bool> Fitter::getFixedParameters() const {
343  Vector<Bool> out(parameters_.nelements());
344  if (fixedpar_.nelements() == 0) {
345    return std::vector<bool>();
346    //throw (AipsError("No parameter mask set."));
347  } else {
348    out = fixedpar_;
349  }
350  std::vector<bool> stlout;
351  out.tovector(stlout);
352  return stlout;
353}
354
355float Fitter::getChisquared() const {
356    return chisquared_;
357}
358
359bool Fitter::fit() {
360  NonLinearFitLM<Float> fitter;
361  CompoundFunction<Float> func;
362
363  uInt n = funcs_.nelements();
364  for (uInt i=0; i<n; ++i) {
365    func.addFunction(*funcs_[i]);
366  }
367
368  fitter.setFunction(func);
369  fitter.setMaxIter(50+n*10);
370  // Convergence criterium
371  fitter.setCriteria(0.001);
372
373  // Fit
374  Vector<Float> sigma(x_.nelements());
375  sigma = 1.0;
376
377  parameters_.resize();
378  parameters_ = fitter.fit(x_, y_, sigma, &m_);
379  if ( !fitter.converged() ) {
380     return false;
381  }
382  std::vector<float> ps;
383  parameters_.tovector(ps);
384  setParameters(ps);
385
386  error_.resize();
387  error_ = fitter.errors();
388
389  chisquared_ = fitter.getChi2();
390
391  residual_.resize();
392  residual_ =  y_;
393  fitter.residual(residual_,x_);
394  // use fitter.residual(model=True) to get the model
395  thefit_.resize(x_.nelements());
396  fitter.residual(thefit_,x_,True);
397  return true;
398}
399
400bool Fitter::lfit() {
401  LinearFit<Float> fitter;
402  CompoundFunction<Float> func;
403
404  uInt n = funcs_.nelements();
405  for (uInt i=0; i<n; ++i) {
406    func.addFunction(*funcs_[i]);
407  }
408
409  fitter.setFunction(func);
410  //fitter.setMaxIter(50+n*10);
411  // Convergence criterium
412  //fitter.setCriteria(0.001);
413
414  // Fit
415  Vector<Float> sigma(x_.nelements());
416  sigma = 1.0;
417
418  parameters_.resize();
419  parameters_ = fitter.fit(x_, y_, sigma, &m_);
420  std::vector<float> ps;
421  parameters_.tovector(ps);
422  setParameters(ps);
423
424  error_.resize();
425  error_ = fitter.errors();
426
427  chisquared_ = fitter.getChi2();
428
429  residual_.resize();
430  residual_ =  y_;
431  fitter.residual(residual_,x_);
432  // use fitter.residual(model=True) to get the model
433  thefit_.resize(x_.nelements());
434  fitter.residual(thefit_,x_,True);
435  return true;
436}
437
438std::vector<float> Fitter::evaluate(int whichComp) const
439{
440  std::vector<float> stlout;
441  uInt idx = uInt(whichComp);
442  Float y;
443  if ( idx < funcs_.nelements() ) {
444    for (uInt i=0; i<x_.nelements(); ++i) {
445      y = (*funcs_[idx])(x_[i]);
446      stlout.push_back(float(y));
447    }
448  }
449  return stlout;
450}
451
452STFitEntry Fitter::getFitEntry() const
453{
454  STFitEntry fit;
455  fit.setParameters(getParameters());
456  fit.setErrors(getErrors());
457  fit.setComponents(funccomponents_);
458  fit.setFunctions(funcnames_);
459  fit.setParmasks(getFixedParameters());
460  return fit;
461}
Note: See TracBrowser for help on using the repository browser.