source: branches/Release-1-fixes/src/SDFitter.cc @ 610

Last change on this file since 610 was 610, checked in by mar637, 19 years ago

Fix for asap0017

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 8.9 KB
Line 
1//#---------------------------------------------------------------------------
2//# SDFitter.cc: A Fitter class for spectra
3//#--------------------------------------------------------------------------
4//# Copyright (C) 2004
5//# ATNF
6//#
7//# This program is free software; you can redistribute it and/or modify it
8//# under the terms of the GNU General Public License as published by the Free
9//# Software Foundation; either version 2 of the License, or (at your option)
10//# any later version.
11//#
12//# This program is distributed in the hope that it will be useful, but
13//# WITHOUT ANY WARRANTY; without even the implied warranty of
14//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
15//# Public License for more details.
16//#
17//# You should have received a copy of the GNU General Public License along
18//# with this program; if not, write to the Free Software Foundation, Inc.,
19//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
20//#
21//# Correspondence concerning this software should be addressed as follows:
22//#        Internet email: Malte.Marquarding@csiro.au
23//#        Postal address: Malte Marquarding,
24//#                        Australia Telescope National Facility,
25//#                        P.O. Box 76,
26//#                        Epping, NSW, 2121,
27//#                        AUSTRALIA
28//#
29//# $Id:
30//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <scimath/Fitting.h>
35#include <scimath/Fitting/LinearFit.h>
36#include <scimath/Functionals/CompiledFunction.h>
37#include <scimath/Functionals/CompoundFunction.h>
38#include <scimath/Functionals/Gaussian1D.h>
39#include <scimath/Functionals/Polynomial.h>
40#include <scimath/Mathematics/AutoDiff.h>
41#include <scimath/Mathematics/AutoDiffMath.h>
42#include <scimath/Fitting/NonLinearFitLM.h>
43#include <components/SpectralComponents/SpectralEstimate.h>
44
45#include "SDFitter.h"
46using namespace asap;
47using namespace casa;
48
49SDFitter::SDFitter()
50{
51}
52
53SDFitter::~SDFitter()
54{
55  reset();
56}
57
58void SDFitter::clear()
59{
60  for (uInt i=0;i< funcs_.nelements();++i) {
61    delete funcs_[i]; funcs_[i] = 0;
62  }
63  funcs_.resize(0,True);
64  parameters_.resize();
65  error_.resize();
66  thefit_.resize();
67  estimate_.resize();
68  chisquared_ = 0.0;
69}
70
71void SDFitter::reset()
72{
73  clear();
74  x_.resize();
75  y_.resize();
76  m_.resize();
77}
78
79
80bool SDFitter::computeEstimate() {
81  if (x_.nelements() == 0 || y_.nelements() == 0)
82    throw (AipsError("No x/y data specified."));
83
84  if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
85    return false;
86  uInt n = funcs_.nelements();
87  SpectralEstimate estimator(n);
88  estimator.setQ(5);
89  Int mn,mx;
90  mn = 0;
91  mx = m_.nelements()-1;
92  for (uInt i=0; i<m_.nelements();++i) {
93    if (m_[i]) {
94      mn = i;
95      break;
96    }
97  }
98  for (uInt j=m_.nelements()-1; j>=0;--j) {
99    if (m_[j]) {
100      mx = j;
101      break;
102    }
103  }
104  mn = 0+x_.nelements()/10;
105  mx = x_.nelements()-x_.nelements()/10;
106  estimator.setRegion(mn,mx);
107  //estimator.setWindowing(True);
108  SpectralList listGauss = estimator.estimate(x_, y_);
109  parameters_.resize(n*3);
110  Gaussian1D<Float>* g = 0;
111  for (uInt i=0; i<n;i++) {
112    g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
113    if (g) {
114      (*g)[0] = listGauss[i].getAmpl();
115      (*g)[1] = listGauss[i].getCenter();
116      (*g)[2] = listGauss[i].getFWHM();
117    }
118  }
119  estimate_.resize();
120  listGauss.evaluate(estimate_,x_);
121  return true;
122}
123
124std::vector<float> SDFitter::getEstimate() const
125{
126  if (estimate_.nelements() == 0)
127    throw (AipsError("No estimate set."));
128  std::vector<float> stlout;
129  estimate_.tovector(stlout);
130  return stlout;
131}
132
133
134bool SDFitter::setExpression(const std::string& expr, int ncomp)
135{
136  clear();
137  if (expr == "gauss") {
138    if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
139    funcs_.resize(ncomp);
140    for (Int k=0; k<ncomp; ++k) {
141      funcs_[k] = new Gaussian1D<Float>();
142    }
143  } else if (expr == "poly") {
144    funcs_.resize(1);
145    funcs_[0] = new Polynomial<Float>(ncomp);
146  } else {
147    cerr << " compiled functions not yet implemented" << endl;
148    //funcs_.resize(1);
149    //funcs_[0] = new CompiledFunction<Float>();
150    //funcs_[0]->setFunction(String(expr));
151    return false;
152  }
153  return true;
154}
155
156bool SDFitter::setData(std::vector<float> absc, std::vector<float> spec,
157                       std::vector<bool> mask)
158{
159    x_.resize();
160    y_.resize();
161    m_.resize();
162    // convert std::vector to casa Vector
163    Vector<Float> tmpx(absc);
164    Vector<Float> tmpy(spec);
165    Vector<Bool> tmpm(mask);
166    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
167    x_ = tmpx;
168    y_ = tmpy;
169    m_ = tmpm;
170    return true;
171}
172
173std::vector<float> SDFitter::getResidual() const
174{
175    if (residual_.nelements() == 0)
176        throw (AipsError("Function not yet fitted."));
177    std::vector<float> stlout;
178    residual_.tovector(stlout);
179    return stlout;
180}
181
182std::vector<float> SDFitter::getFit() const
183{
184    Vector<Float> out = thefit_;
185    std::vector<float> stlout;
186    out.tovector(stlout);
187    return stlout;
188
189}
190
191std::vector<float> SDFitter::getErrors() const
192{
193    Vector<Float> out = error_;
194    std::vector<float> stlout;
195    out.tovector(stlout);
196    return stlout;
197}
198
199bool SDFitter::setParameters(std::vector<float> params)
200{
201    Vector<Float> tmppar(params);
202    if (funcs_.nelements() == 0)
203        throw (AipsError("Function not yet set."));
204    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
205        throw (AipsError("Number of parameters inconsistent with function."));
206    if (parameters_.nelements() == 0)
207        parameters_.resize(tmppar.nelements());
208        fixedpar_.resize(tmppar.nelements());
209        fixedpar_ = False;
210    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
211        uInt count = 0;
212        for (uInt j=0; j < funcs_.nelements(); ++j) {
213            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
214                (funcs_[j]->parameters())[i] = tmppar[count];
215                parameters_[count] = tmppar[count];
216                ++count;
217            }
218        }
219    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
220        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
221            parameters_[i] = tmppar[i];
222            (funcs_[0]->parameters())[i] =  tmppar[i];
223        }
224    }
225    return true;
226}
227
228bool SDFitter::setFixedParameters(std::vector<bool> fixed)
229{
230    Vector<Bool> tmp(fixed);
231    if (funcs_.nelements() == 0)
232        throw (AipsError("Function not yet set."));
233    if (fixedpar_.nelements() > 0 && tmp.nelements() != fixedpar_.nelements())
234        throw (AipsError("Number of mask elements inconsistent with function."));
235    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
236        uInt count = 0;
237        for (uInt j=0; j < funcs_.nelements(); ++j) {
238            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
239                funcs_[j]->mask(i) = !tmp[count];
240                fixedpar_[count] = !tmp[count];
241                ++count;
242            }
243        }
244    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
245        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
246            fixedpar_[i] = tmp[i];
247            funcs_[0]->mask(i) =  tmp[i];
248        }
249    }
250    //fixedpar_ = !tmpmsk;
251    return true;
252}
253
254std::vector<float> SDFitter::getParameters() const {
255    Vector<Float> out = parameters_;
256    std::vector<float> stlout;
257    out.tovector(stlout);
258    return stlout;
259}
260
261std::vector<bool> SDFitter::getFixedParameters() const {
262  Vector<Bool> out(parameters_.nelements());
263  if (fixedpar_.nelements() == 0) {
264    out = False;
265    //throw (AipsError("No parameter mask set."));
266  } else {
267    out = fixedpar_;
268  }
269  std::vector<bool> stlout;
270  out.tovector(stlout);
271  return stlout;
272}
273
274float SDFitter::getChisquared() const {
275    return chisquared_;
276}
277
278bool SDFitter::fit() {
279  NonLinearFitLM<Float> fitter;
280  CompoundFunction<Float> func;
281
282  uInt n = funcs_.nelements();
283  for (uInt i=0; i<n; ++i) {
284    func.addFunction(*funcs_[i]);
285  }
286
287  fitter.setFunction(func);
288  fitter.setMaxIter(50+n*10);
289  // Convergence criterium
290  fitter.setCriteria(0.001);
291
292  // Fit
293  Vector<Float> sigma(x_.nelements());
294  sigma = 1.0;
295 
296  parameters_.resize();
297  parameters_ = fitter.fit(x_, y_, sigma, &m_);
298  std::vector<float> ps;
299  parameters_.tovector(ps);
300  setParameters(ps);
301
302  error_.resize();
303  error_ = fitter.errors();
304
305  chisquared_ = fitter.getChi2();
306 
307  residual_.resize();
308  residual_ =  y_;
309  fitter.residual(residual_,x_);
310
311  // use fitter.residual(model=True) to get the model
312  thefit_.resize(x_.nelements());
313  fitter.residual(thefit_,x_,True);
314  return true;
315}
316
317
318std::vector<float> SDFitter::evaluate(int whichComp) const
319
320  std::vector<float> stlout;
321  uInt idx = uInt(whichComp);
322  Float y;
323  if ( idx < funcs_.nelements() ) {
324    for (uInt i=0; i<x_.nelements(); ++i) {
325      y = (*funcs_[idx])(x_[i]);
326      stlout.push_back(float(y));
327    }
328  }
329  return stlout;
330}
331
Note: See TracBrowser for help on using the repository browser.