source: branches/mergetest/src/STFitter.cpp @ 1779

Last change on this file since 1779 was 1779, checked in by Kana Sugimoto, 14 years ago

New Development: Yes

JIRA Issue: No (test merging alma branch)

Ready for Test: Yes

Interface Changes: Yes

What Interface Changed:

Test Programs:

Put in Release Notes: No

Module(s):

Description:


  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 11.1 KB
Line 
1//#---------------------------------------------------------------------------
2//# Fitter.cc: A Fitter class for spectra
3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004
5//# ATNF
6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
29//# $Id: STFitter.cpp 1779 2010-07-29 09:13:46Z KanaSugimoto $
30//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <casa/Logging/LogIO.h>
35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
40#include "Lorentzian1D.h"
41#include <scimath/Functionals/Polynomial.h>
42#include <scimath/Mathematics/AutoDiff.h>
43#include <scimath/Mathematics/AutoDiffMath.h>
44#include <scimath/Fitting/NonLinearFitLM.h>
45#include <components/SpectralComponents/SpectralEstimate.h>
46
47#include "STFitter.h"
48
49using namespace asap;
50using namespace casa;
51
52Fitter::Fitter()
53{
54}
55
56Fitter::~Fitter()
57{
58  reset();
59}
60
61void Fitter::clear()
62{
63  for (uInt i=0;i< funcs_.nelements();++i) {
64    delete funcs_[i]; funcs_[i] = 0;
65  }
66  funcs_.resize(0,True);
67  parameters_.resize();
68  fixedpar_.resize();
69  error_.resize();
70  thefit_.resize();
71  estimate_.resize();
72  chisquared_ = 0.0;
73}
74
75void Fitter::reset()
76{
77  clear();
78  x_.resize();
79  y_.resize();
80  m_.resize();
81}
82
83
84bool Fitter::computeEstimate() {
85  if (x_.nelements() == 0 || y_.nelements() == 0)
86    throw (AipsError("No x/y data specified."));
87
88  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
89    return false;
90  uInt n = funcs_.nelements();
91  SpectralEstimate estimator(n);
92  estimator.setQ(5);
93  Int mn,mx;
94  mn = 0;
95  mx = m_.nelements()-1;
96  for (uInt i=0; i<m_.nelements();++i) {
97    if (m_[i]) {
98      mn = i;
99      break;
100    }
101  }
102  for (uInt j=m_.nelements()-1; j>=0;--j) {
103    if (m_[j]) {
104      mx = j;
105      break;
106    }
107  }
108  //mn = 0+x_.nelements()/10;
109  //mx = x_.nelements()-x_.nelements()/10;
110  estimator.setRegion(mn,mx);
111  //estimator.setWindowing(True);
112  SpectralList listGauss = estimator.estimate(x_, y_);
113  parameters_.resize(n*3);
114  Gaussian1D<Float>* g = 0;
115  for (uInt i=0; i<n;i++) {
116    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
117    if (g) {
118      (*g)[0] = listGauss[i].getAmpl();
119      (*g)[1] = listGauss[i].getCenter();
120      (*g)[2] = listGauss[i].getFWHM();
121    }
122  }
123  estimate_.resize();
124  listGauss.evaluate(estimate_,x_);
125  return true;
126}
127
128std::vector<float> Fitter::getEstimate() const
129{
130  if (estimate_.nelements() == 0)
131    throw (AipsError("No estimate set."));
132  std::vector<float> stlout;
133  estimate_.tovector(stlout);
134  return stlout;
135}
136
137
138bool Fitter::setExpression(const std::string& expr, int ncomp)
139{
140  clear();
141  if (expr == "gauss") {
142    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
143    funcs_.resize(ncomp);
144    for (Int k=0; k<ncomp; ++k) {
145      funcs_[k] = new Gaussian1D<Float>();
146    }
147  } else if (expr == "poly") {
148    funcs_.resize(1);
149    funcs_[0] = new Polynomial<Float>(ncomp);
150  } else if (expr == "lorentz") {
151    if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
152    funcs_.resize(ncomp);
153    for (Int k=0; k<ncomp; ++k) {
154      funcs_[k] = new Lorentzian1D<Float>();
155    }
156  } else {
157    //cerr << " compiled functions not yet implemented" << endl;
158    LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
159    os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
160    //funcs_.resize(1);
161    //funcs_[0] = new CompiledFunction<Float>();
162    //funcs_[0]->setFunction(String(expr));
163    return false;
164  }
165  return true;
166}
167
168bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
169                       std::vector<bool> mask)
170{
171    x_.resize();
172    y_.resize();
173    m_.resize();
174    // convert std::vector to casa Vector
175    Vector<Float> tmpx(absc);
176    Vector<Float> tmpy(spec);
177    Vector<Bool> tmpm(mask);
178    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
179    x_ = tmpx;
180    y_ = tmpy;
181    m_ = tmpm;
182    return true;
183}
184
185std::vector<float> Fitter::getResidual() const
186{
187    if (residual_.nelements() == 0)
188        throw (AipsError("Function not yet fitted."));
189    std::vector<float> stlout;
190    residual_.tovector(stlout);
191    return stlout;
192}
193
194std::vector<float> Fitter::getFit() const
195{
196    Vector<Float> out = thefit_;
197    std::vector<float> stlout;
198    out.tovector(stlout);
199    return stlout;
200
201}
202
203std::vector<float> Fitter::getErrors() const
204{
205    Vector<Float> out = error_;
206    std::vector<float> stlout;
207    out.tovector(stlout);
208    return stlout;
209}
210
211bool Fitter::setParameters(std::vector<float> params)
212{
213    Vector<Float> tmppar(params);
214    if (funcs_.nelements() == 0)
215        throw (AipsError("Function not yet set."));
216    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
217        throw (AipsError("Number of parameters inconsistent with function."));
218    if (parameters_.nelements() == 0) {
219        parameters_.resize(tmppar.nelements());
220        if (tmppar.nelements() != fixedpar_.nelements()) {
221            fixedpar_.resize(tmppar.nelements());
222            fixedpar_ = False;
223        }
224    }
225    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
226        uInt count = 0;
227        for (uInt j=0; j < funcs_.nelements(); ++j) {
228            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
229                (funcs_[j]->parameters())[i] = tmppar[count];
230                parameters_[count] = tmppar[count];
231                ++count;
232            }
233        }
234    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
235        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
236            parameters_[i] = tmppar[i];
237            (funcs_[0]->parameters())[i] =  tmppar[i];
238        }
239    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
240        uInt count = 0;
241        for (uInt j=0; j < funcs_.nelements(); ++j) {
242            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
243                (funcs_[j]->parameters())[i] = tmppar[count];
244                parameters_[count] = tmppar[count];
245                ++count;
246            }
247        }
248    }
249    // reset
250    if (params.size() == 0) {
251        parameters_.resize();
252        fixedpar_.resize();
253    }
254    return true;
255}
256
257bool Fitter::setFixedParameters(std::vector<bool> fixed)
258{
259    if (funcs_.nelements() == 0)
260        throw (AipsError("Function not yet set."));
261    if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
262        throw (AipsError("Number of mask elements inconsistent with function."));
263    if (fixedpar_.nelements() == 0) {
264        fixedpar_.resize(parameters_.nelements());
265        fixedpar_ = False;
266    }
267    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
268        uInt count = 0;
269        for (uInt j=0; j < funcs_.nelements(); ++j) {
270            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
271                funcs_[j]->mask(i) = !fixed[count];
272                fixedpar_[count] = fixed[count];
273                ++count;
274            }
275        }
276    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
277        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
278            fixedpar_[i] = fixed[i];
279            funcs_[0]->mask(i) =  !fixed[i];
280        }
281    } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
282      uInt count = 0;
283        for (uInt j=0; j < funcs_.nelements(); ++j) {
284            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
285                funcs_[j]->mask(i) = !fixed[count];
286                fixedpar_[count] = fixed[count];
287                ++count;
288            }
289        }
290    }
291    return true;
292}
293
294std::vector<float> Fitter::getParameters() const {
295    Vector<Float> out = parameters_;
296    std::vector<float> stlout;
297    out.tovector(stlout);
298    return stlout;
299}
300
301std::vector<bool> Fitter::getFixedParameters() const {
302  Vector<Bool> out(parameters_.nelements());
303  if (fixedpar_.nelements() == 0) {
304    return std::vector<bool>();
305    //throw (AipsError("No parameter mask set."));
306  } else {
307    out = fixedpar_;
308  }
309  std::vector<bool> stlout;
310  out.tovector(stlout);
311  return stlout;
312}
313
314float Fitter::getChisquared() const {
315    return chisquared_;
316}
317
318bool Fitter::fit() {
319  NonLinearFitLM<Float> fitter;
320  CompoundFunction<Float> func;
321
322  uInt n = funcs_.nelements();
323  for (uInt i=0; i<n; ++i) {
324    func.addFunction(*funcs_[i]);
325  }
326
327  fitter.setFunction(func);
328  fitter.setMaxIter(50+n*10);
329  // Convergence criterium
330  fitter.setCriteria(0.001);
331
332  // Fit
333  Vector<Float> sigma(x_.nelements());
334  sigma = 1.0;
335
336  parameters_.resize();
337  parameters_ = fitter.fit(x_, y_, sigma, &m_);
338  if ( !fitter.converged() ) {
339     return false;
340  }
341  std::vector<float> ps;
342  parameters_.tovector(ps);
343  setParameters(ps);
344
345  error_.resize();
346  error_ = fitter.errors();
347
348  chisquared_ = fitter.getChi2();
349
350  residual_.resize();
351  residual_ =  y_;
352  fitter.residual(residual_,x_);
353  // use fitter.residual(model=True) to get the model
354  thefit_.resize(x_.nelements());
355  fitter.residual(thefit_,x_,True);
356  return true;
357}
358
359bool Fitter::lfit() {
360  LinearFit<Float> fitter;
361  CompoundFunction<Float> func;
362
363  uInt n = funcs_.nelements();
364  for (uInt i=0; i<n; ++i) {
365    func.addFunction(*funcs_[i]);
366  }
367
368  fitter.setFunction(func);
369  //fitter.setMaxIter(50+n*10);
370  // Convergence criterium
371  //fitter.setCriteria(0.001);
372
373  // Fit
374  Vector<Float> sigma(x_.nelements());
375  sigma = 1.0;
376
377  parameters_.resize();
378  parameters_ = fitter.fit(x_, y_, sigma, &m_);
379  std::vector<float> ps;
380  parameters_.tovector(ps);
381  setParameters(ps);
382
383  error_.resize();
384  error_ = fitter.errors();
385
386  chisquared_ = fitter.getChi2();
387
388  residual_.resize();
389  residual_ =  y_;
390  fitter.residual(residual_,x_);
391  // use fitter.residual(model=True) to get the model
392  thefit_.resize(x_.nelements());
393  fitter.residual(thefit_,x_,True);
394  return true;
395}
396
397std::vector<float> Fitter::evaluate(int whichComp) const
398{
399  std::vector<float> stlout;
400  uInt idx = uInt(whichComp);
401  Float y;
402  if ( idx < funcs_.nelements() ) {
403    for (uInt i=0; i<x_.nelements(); ++i) {
404      y = (*funcs_[idx])(x_[i]);
405      stlout.push_back(float(y));
406    }
407  }
408  return stlout;
409}
410
Note: See TracBrowser for help on using the repository browser.