source: trunk/src/STFitter.cpp@ 2197

Last change on this file since 2197 was 2163, checked in by Malte Marquarding, 14 years ago

Remove various compiler warnings

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 12.8 KB
Line 
1//#---------------------------------------------------------------------------
2//# Fitter.cc: A Fitter class for spectra
3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004
5//# ATNF
6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//# Internet email: Malte.Marquarding@csiro.au
23//# Postal address: Malte Marquarding,
24//# Australia Telescope National Facility,
25//# P.O. Box 76,
26//# Epping, NSW, 2121,
27//# AUSTRALIA
28//#
29//# $Id: STFitter.cpp 2163 2011-05-10 05:02:56Z MalteMarquarding $
30//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <casa/Logging/LogIO.h>
35#include <scimath/Fitting.h>
36#include <scimath/Fitting/LinearFit.h>
37#include <scimath/Functionals/CompiledFunction.h>
38#include <scimath/Functionals/CompoundFunction.h>
39#include <scimath/Functionals/Gaussian1D.h>
40#include "Lorentzian1D.h"
41#include <scimath/Functionals/Sinusoid1D.h>
42#include <scimath/Functionals/Polynomial.h>
43#include <scimath/Mathematics/AutoDiff.h>
44#include <scimath/Mathematics/AutoDiffMath.h>
45#include <scimath/Fitting/NonLinearFitLM.h>
46#include <components/SpectralComponents/SpectralEstimate.h>
47
48#include "STFitter.h"
49
50using namespace asap;
51using namespace casa;
52
53Fitter::Fitter()
54{
55}
56
57Fitter::~Fitter()
58{
59 reset();
60}
61
62void Fitter::clear()
63{
64 for (uInt i=0;i< funcs_.nelements();++i) {
65 delete funcs_[i]; funcs_[i] = 0;
66 }
67 funcs_.resize(0,True);
68 parameters_.resize();
69 fixedpar_.resize();
70 error_.resize();
71 thefit_.resize();
72 estimate_.resize();
73 chisquared_ = 0.0;
74}
75
76void Fitter::reset()
77{
78 clear();
79 x_.resize();
80 y_.resize();
81 m_.resize();
82}
83
84
85bool Fitter::computeEstimate() {
86 if (x_.nelements() == 0 || y_.nelements() == 0)
87 throw (AipsError("No x/y data specified."));
88
89 if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
90 return false;
91 uInt n = funcs_.nelements();
92 SpectralEstimate estimator(n);
93 estimator.setQ(5);
94 Int mn,mx;
95 mn = 0;
96 mx = m_.nelements()-1;
97 for (uInt i=0; i<m_.nelements();++i) {
98 if (m_[i]) {
99 mn = i;
100 break;
101 }
102 }
103 // use Int to suppress compiler warning
104 for (Int j=m_.nelements()-1; j>=0;--j) {
105 if (m_[j]) {
106 mx = j;
107 break;
108 }
109 }
110 //mn = 0+x_.nelements()/10;
111 //mx = x_.nelements()-x_.nelements()/10;
112 estimator.setRegion(mn,mx);
113 //estimator.setWindowing(True);
114 SpectralList listGauss = estimator.estimate(x_, y_);
115 parameters_.resize(n*3);
116 Gaussian1D<Float>* g = 0;
117 for (uInt i=0; i<n;i++) {
118 g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
119 if (g) {
120 (*g)[0] = listGauss[i].getAmpl();
121 (*g)[1] = listGauss[i].getCenter();
122 (*g)[2] = listGauss[i].getFWHM();
123 }
124 }
125 estimate_.resize();
126 listGauss.evaluate(estimate_,x_);
127 return true;
128}
129
130std::vector<float> Fitter::getEstimate() const
131{
132 if (estimate_.nelements() == 0)
133 throw (AipsError("No estimate set."));
134 std::vector<float> stlout;
135 estimate_.tovector(stlout);
136 return stlout;
137}
138
139
140bool Fitter::setExpression(const std::string& expr, int ncomp)
141{
142 clear();
143 if (expr == "gauss") {
144 if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
145 funcs_.resize(ncomp);
146 funcnames_.clear();
147 funccomponents_.clear();
148 for (Int k=0; k<ncomp; ++k) {
149 funcs_[k] = new Gaussian1D<Float>();
150 funcnames_.push_back(expr);
151 funccomponents_.push_back(3);
152 }
153 } else if (expr == "lorentz") {
154 if (ncomp < 1) throw (AipsError("Need at least one lorentzian to fit."));
155 funcs_.resize(ncomp);
156 funcnames_.clear();
157 funccomponents_.clear();
158 for (Int k=0; k<ncomp; ++k) {
159 funcs_[k] = new Lorentzian1D<Float>();
160 funcnames_.push_back(expr);
161 funccomponents_.push_back(3);
162 }
163 } else if (expr == "sinusoid") {
164 if (ncomp < 1) throw (AipsError("Need at least one sinusoid to fit."));
165 funcs_.resize(ncomp);
166 funcnames_.clear();
167 funccomponents_.clear();
168 for (Int k=0; k<ncomp; ++k) {
169 funcs_[k] = new Sinusoid1D<Float>();
170 funcnames_.push_back(expr);
171 funccomponents_.push_back(3);
172 }
173 } else if (expr == "poly") {
174 funcs_.resize(1);
175 funcnames_.clear();
176 funccomponents_.clear();
177 funcs_[0] = new Polynomial<Float>(ncomp);
178 funcnames_.push_back(expr);
179 funccomponents_.push_back(ncomp);
180 } else {
181 LogIO os( LogOrigin( "Fitter", "setExpression()", WHERE ) ) ;
182 os << LogIO::WARN << " compiled functions not yet implemented" << LogIO::POST;
183 //funcs_.resize(1);
184 //funcs_[0] = new CompiledFunction<Float>();
185 //funcs_[0]->setFunction(String(expr));
186 return false;
187 }
188 return true;
189}
190
191bool Fitter::setData(std::vector<float> absc, std::vector<float> spec,
192 std::vector<bool> mask)
193{
194 x_.resize();
195 y_.resize();
196 m_.resize();
197 // convert std::vector to casa Vector
198 Vector<Float> tmpx(absc);
199 Vector<Float> tmpy(spec);
200 Vector<Bool> tmpm(mask);
201 AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
202 x_ = tmpx;
203 y_ = tmpy;
204 m_ = tmpm;
205 return true;
206}
207
208std::vector<float> Fitter::getResidual() const
209{
210 if (residual_.nelements() == 0)
211 throw (AipsError("Function not yet fitted."));
212 std::vector<float> stlout;
213 residual_.tovector(stlout);
214 return stlout;
215}
216
217std::vector<float> Fitter::getFit() const
218{
219 Vector<Float> out = thefit_;
220 std::vector<float> stlout;
221 out.tovector(stlout);
222 return stlout;
223
224}
225
226std::vector<float> Fitter::getErrors() const
227{
228 Vector<Float> out = error_;
229 std::vector<float> stlout;
230 out.tovector(stlout);
231 return stlout;
232}
233
234bool Fitter::setParameters(std::vector<float> params)
235{
236 Vector<Float> tmppar(params);
237 if (funcs_.nelements() == 0)
238 throw (AipsError("Function not yet set."));
239 if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
240 throw (AipsError("Number of parameters inconsistent with function."));
241 if (parameters_.nelements() == 0) {
242 parameters_.resize(tmppar.nelements());
243 if (tmppar.nelements() != fixedpar_.nelements()) {
244 fixedpar_.resize(tmppar.nelements());
245 fixedpar_ = False;
246 }
247 }
248 if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
249 uInt count = 0;
250 for (uInt j=0; j < funcs_.nelements(); ++j) {
251 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
252 (funcs_[j]->parameters())[i] = tmppar[count];
253 parameters_[count] = tmppar[count];
254 ++count;
255 }
256 }
257 } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
258 uInt count = 0;
259 for (uInt j=0; j < funcs_.nelements(); ++j) {
260 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
261 (funcs_[j]->parameters())[i] = tmppar[count];
262 parameters_[count] = tmppar[count];
263 ++count;
264 }
265 }
266 } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
267 uInt count = 0;
268 for (uInt j=0; j < funcs_.nelements(); ++j) {
269 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
270 (funcs_[j]->parameters())[i] = tmppar[count];
271 parameters_[count] = tmppar[count];
272 ++count;
273 }
274 }
275 } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
276 for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
277 parameters_[i] = tmppar[i];
278 (funcs_[0]->parameters())[i] = tmppar[i];
279 }
280 }
281 // reset
282 if (params.size() == 0) {
283 parameters_.resize();
284 fixedpar_.resize();
285 }
286 return true;
287}
288
289bool Fitter::setFixedParameters(std::vector<bool> fixed)
290{
291 if (funcs_.nelements() == 0)
292 throw (AipsError("Function not yet set."));
293 if (fixedpar_.nelements() > 0 && fixed.size() != fixedpar_.nelements())
294 throw (AipsError("Number of mask elements inconsistent with function."));
295 if (fixedpar_.nelements() == 0) {
296 fixedpar_.resize(parameters_.nelements());
297 fixedpar_ = False;
298 }
299 if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
300 uInt count = 0;
301 for (uInt j=0; j < funcs_.nelements(); ++j) {
302 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
303 funcs_[j]->mask(i) = !fixed[count];
304 fixedpar_[count] = fixed[count];
305 ++count;
306 }
307 }
308 } else if (dynamic_cast<Lorentzian1D<Float>* >(funcs_[0]) != 0) {
309 uInt count = 0;
310 for (uInt j=0; j < funcs_.nelements(); ++j) {
311 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
312 funcs_[j]->mask(i) = !fixed[count];
313 fixedpar_[count] = fixed[count];
314 ++count;
315 }
316 }
317 } else if (dynamic_cast<Sinusoid1D<Float>* >(funcs_[0]) != 0) {
318 uInt count = 0;
319 for (uInt j=0; j < funcs_.nelements(); ++j) {
320 for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
321 funcs_[j]->mask(i) = !fixed[count];
322 fixedpar_[count] = fixed[count];
323 ++count;
324 }
325 }
326 } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
327 for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
328 fixedpar_[i] = fixed[i];
329 funcs_[0]->mask(i) = !fixed[i];
330 }
331 }
332 return true;
333}
334
335std::vector<float> Fitter::getParameters() const {
336 Vector<Float> out = parameters_;
337 std::vector<float> stlout;
338 out.tovector(stlout);
339 return stlout;
340}
341
342std::vector<bool> Fitter::getFixedParameters() const {
343 Vector<Bool> out(parameters_.nelements());
344 if (fixedpar_.nelements() == 0) {
345 return std::vector<bool>();
346 //throw (AipsError("No parameter mask set."));
347 } else {
348 out = fixedpar_;
349 }
350 std::vector<bool> stlout;
351 out.tovector(stlout);
352 return stlout;
353}
354
355float Fitter::getChisquared() const {
356 return chisquared_;
357}
358
359bool Fitter::fit() {
360 NonLinearFitLM<Float> fitter;
361 CompoundFunction<Float> func;
362
363 uInt n = funcs_.nelements();
364 for (uInt i=0; i<n; ++i) {
365 func.addFunction(*funcs_[i]);
366 }
367
368 fitter.setFunction(func);
369 fitter.setMaxIter(50+n*10);
370 // Convergence criterium
371 fitter.setCriteria(0.001);
372
373 // Fit
374 Vector<Float> sigma(x_.nelements());
375 sigma = 1.0;
376
377 parameters_.resize();
378 parameters_ = fitter.fit(x_, y_, sigma, &m_);
379 if ( !fitter.converged() ) {
380 return false;
381 }
382 std::vector<float> ps;
383 parameters_.tovector(ps);
384 setParameters(ps);
385
386 error_.resize();
387 error_ = fitter.errors();
388
389 chisquared_ = fitter.getChi2();
390
391 residual_.resize();
392 residual_ = y_;
393 fitter.residual(residual_,x_);
394 // use fitter.residual(model=True) to get the model
395 thefit_.resize(x_.nelements());
396 fitter.residual(thefit_,x_,True);
397 return true;
398}
399
400bool Fitter::lfit() {
401 LinearFit<Float> fitter;
402 CompoundFunction<Float> func;
403
404 uInt n = funcs_.nelements();
405 for (uInt i=0; i<n; ++i) {
406 func.addFunction(*funcs_[i]);
407 }
408
409 fitter.setFunction(func);
410 //fitter.setMaxIter(50+n*10);
411 // Convergence criterium
412 //fitter.setCriteria(0.001);
413
414 // Fit
415 Vector<Float> sigma(x_.nelements());
416 sigma = 1.0;
417
418 parameters_.resize();
419 parameters_ = fitter.fit(x_, y_, sigma, &m_);
420 std::vector<float> ps;
421 parameters_.tovector(ps);
422 setParameters(ps);
423
424 error_.resize();
425 error_ = fitter.errors();
426
427 chisquared_ = fitter.getChi2();
428
429 residual_.resize();
430 residual_ = y_;
431 fitter.residual(residual_,x_);
432 // use fitter.residual(model=True) to get the model
433 thefit_.resize(x_.nelements());
434 fitter.residual(thefit_,x_,True);
435 return true;
436}
437
438std::vector<float> Fitter::evaluate(int whichComp) const
439{
440 std::vector<float> stlout;
441 uInt idx = uInt(whichComp);
442 Float y;
443 if ( idx < funcs_.nelements() ) {
444 for (uInt i=0; i<x_.nelements(); ++i) {
445 y = (*funcs_[idx])(x_[i]);
446 stlout.push_back(float(y));
447 }
448 }
449 return stlout;
450}
451
452STFitEntry Fitter::getFitEntry() const
453{
454 STFitEntry fit;
455 fit.setParameters(getParameters());
456 fit.setErrors(getErrors());
457 fit.setComponents(funccomponents_);
458 fit.setFunctions(funcnames_);
459 fit.setParmasks(getFixedParameters());
460 return fit;
461}
Note: See TracBrowser for help on using the repository browser.