source: trunk/src/SDFitter.cc @ 483

Last change on this file since 483 was 483, checked in by mar637, 19 years ago
  • Added history support.
  • Added version keyword to SDMemTable.
  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 10.0 KB
Line 
//#---------------------------------------------------------------------------
//# SDFitter.cc: A Fitter class for spectra
//#--------------------------------------------------------------------------
//# Copyright (C) 2004
//# ATNF
//#
//# This program is free software; you can redistribute it and/or modify it
//# under the terms of the GNU General Public License as published by the Free
//# Software Foundation; either version 2 of the License, or (at your option)
//# any later version.
//#
//# This program is distributed in the hope that it will be useful, but
//# WITHOUT ANY WARRANTY; without even the implied warranty of
//# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
//# Public License for more details.
//#
//# You should have received a copy of the GNU General Public License along
//# with this program; if not, write to the Free Software Foundation, Inc.,
//# 675 Massachusetts Ave, Cambridge, MA 02139, USA.
//#
//# Correspondence concerning this software should be addressed as follows:
//#        Internet email: Malte.Marquarding@csiro.au
//#        Postal address: Malte Marquarding,
//#                        Australia Telescope National Facility,
//#                        P.O. Box 76,
//#                        Epping, NSW, 2121,
//#                        AUSTRALIA
//#
//# $Id:
//#---------------------------------------------------------------------------
31#include <casa/aips.h>
32#include <casa/Arrays/ArrayMath.h>
33#include <casa/Arrays/ArrayLogical.h>
34#include <scimath/Fitting.h>
35#include <scimath/Fitting/LinearFit.h>
36#include <scimath/Functionals/CompiledFunction.h>
37#include <scimath/Functionals/CompoundFunction.h>
38#include <scimath/Functionals/Gaussian1D.h>
39#include <scimath/Functionals/Polynomial.h>
40#include <scimath/Mathematics/AutoDiff.h>
41#include <scimath/Mathematics/AutoDiffMath.h>
42#include <scimath/Fitting/NonLinearFitLM.h>
43#include <components/SpectralComponents/SpectralEstimate.h>
44
45#include "SDFitter.h"
46using namespace asap;
47using namespace casa;
48
49SDFitter::SDFitter()
50{
51}
52
53SDFitter::~SDFitter()
54{
55    reset();
56}
57
58void SDFitter::clear()
59{
60    for (uInt i=0;i< funcs_.nelements();++i) {
61        delete funcs_[i]; funcs_[i] = 0;
62    }
63    funcs_.resize(0, True);
64    parameters_.resize();
65    error_.resize();
66    thefit_.resize();
67    estimate_.resize();
68    chisquared_ = 0.0;
69}
70void SDFitter::reset()
71{
72    clear();
73    x_.resize();
74    y_.resize();
75    m_.resize();
76}
77
78
79bool SDFitter::computeEstimate() {
80    if (x_.nelements() == 0 || y_.nelements() == 0)
81        throw (AipsError("No x/y data specified."));
82
83    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) == 0)
84        return false;
85    uInt n = funcs_.nelements();
86    SpectralEstimate estimator(n);
87    estimator.setQ(5);
88    Int mn,mx;
89    mn = 0;
90    mx = m_.nelements()-1;
91    for (uInt i=0; i<m_.nelements();++i) {
92      if (m_[i]) {
93        mn = i;
94        break;
95      }
96    }
97    for (uInt j=m_.nelements()-1; j>=0;--j) {
98      if (m_[j]) {
99        mx = j;
100        break;
101      }
102    }
103    //mn = 0+x_.nelements()/10;
104    //mx = x_.nelements()-x_.nelements()/10;
105    estimator.setRegion(mn,mx);
106    //estimator.setWindowing(True);
107    SpectralList listGauss = estimator.estimate(x_, y_);
108    Gaussian1D<Float>* g;
109    parameters_.resize(n*3);
110    uInt count = 0;
111    for (uInt i=0; i<n;i++) {
112        g = dynamic_cast<Gaussian1D<Float>* >(funcs_[i]);
113        if (g) {
114            (*g)[0] = listGauss[i].getAmpl();
115            (*g)[1] = listGauss[i].getCenter();
116            (*g)[2] = listGauss[i].getFWHM();
117            ++count;
118        }
119    }
120    estimate_.resize();
121    listGauss.evaluate(estimate_,x_);
122    return true;
123}
124
125std::vector<float> SDFitter::getEstimate() const
126{
127    if (estimate_.nelements() == 0)
128        throw (AipsError("No estimate set."));
129    std::vector<float> stlout;
130    estimate_.tovector(stlout);
131    return stlout;
132}
133
134
135bool SDFitter::setExpression(const std::string& expr, int ncomp)
136{
137    clear();
138    if (expr == "gauss") {
139        if (ncomp < 1) throw (AipsError("Need at least one gaussian to fit."));
140        funcs_.resize(ncomp);
141        for (Int k=0; k<ncomp; ++k) {
142            funcs_[k] = new Gaussian1D<Float>();
143        }
144    } else if (expr == "poly") {
145        funcs_.resize(1);
146        funcs_[0] = new Polynomial<Float>(ncomp);
147    } else {
148        //cerr << " compiled functions not yet implemented" << endl;
149        //funcs_.resize(1);
150        //funcs_[0] = new CompiledFunction<Float>();
151        //funcs_[0]->setFunction(String(expr));
152        return false;
153    };
154    return true;
155}
156
157bool SDFitter::setData(std::vector<float> absc, std::vector<float> spec,
158                       std::vector<bool> mask)
159{
160    x_.resize();
161    y_.resize();
162    m_.resize();
163    // convert std::vector to casa Vector
164    Vector<Float> tmpx(absc);
165    Vector<Float> tmpy(spec);
166    Vector<Bool> tmpm(mask);
167    AlwaysAssert(tmpx.nelements() == tmpy.nelements(), AipsError);
168    x_ = tmpx;
169    y_ = tmpy;
170    m_ = tmpm;
171    return true;
172}
173
174std::vector<float> SDFitter::getResidual() const
175{
176    if (residual_.nelements() == 0)
177        throw (AipsError("Function not yet fitted."));
178    std::vector<float> stlout;
179    residual_.tovector(stlout);
180    return stlout;
181}
182
183std::vector<float> SDFitter::getFit() const
184{
185    Vector<Float> out = thefit_;
186    std::vector<float> stlout;
187    out.tovector(stlout);
188    return stlout;
189
190}
191
192std::vector<float> SDFitter::getErrors() const
193{
194    Vector<Float> out = error_;
195    std::vector<float> stlout;
196    out.tovector(stlout);
197    return stlout;
198}
199
200bool SDFitter::setParameters(std::vector<float> params)
201{
202    Vector<Float> tmppar(params);
203    if (funcs_.nelements() == 0)
204        throw (AipsError("Function not yet set."));
205    if (parameters_.nelements() > 0 && tmppar.nelements() != parameters_.nelements())
206        throw (AipsError("Number of parameters inconsistent with function."));
207    if (parameters_.nelements() == 0)
208        parameters_.resize(tmppar.nelements());
209        fixedpar_.resize(tmppar.nelements());
210        fixedpar_ = False;
211    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
212        uInt count = 0;
213        for (uInt j=0; j < funcs_.nelements(); ++j) {
214            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
215                (funcs_[j]->parameters())[i] = tmppar[count];
216                parameters_[count] = tmppar[count];
217                ++count;
218            }
219        }
220    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
221        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
222            parameters_[i] = tmppar[i];
223            (funcs_[0]->parameters())[i] =  tmppar[i];
224        }
225    }
226    return true;
227}
228
229bool SDFitter::setFixedParameters(std::vector<bool> fixed)
230{
231    Vector<Bool> tmp(fixed);
232    if (funcs_.nelements() == 0)
233        throw (AipsError("Function not yet set."));
234    if (fixedpar_.nelements() > 0 && tmp.nelements() != fixedpar_.nelements())
235        throw (AipsError("Number of mask elements inconsistent with function."));
236    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
237        uInt count = 0;
238        for (uInt j=0; j < funcs_.nelements(); ++j) {
239            for (uInt i=0; i < funcs_[j]->nparameters(); ++i) {
240                funcs_[j]->mask(i) = !tmp[count];
241                fixedpar_[count] = !tmp[count];
242                ++count;
243            }
244        }
245    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
246        for (uInt i=0; i < funcs_[0]->nparameters(); ++i) {
247            fixedpar_[i] = tmp[i];
248            funcs_[0]->mask(i) =  tmp[i];
249        }
250    }
251    //fixedpar_ = !tmpmsk;
252    return true;
253}
254
255std::vector<float> SDFitter::getParameters() const {
256    Vector<Float> out = parameters_;
257    std::vector<float> stlout;
258    out.tovector(stlout);
259    return stlout;
260}
261
262std::vector<bool> SDFitter::getFixedParameters() const {
263  Vector<Bool> out(parameters_.nelements());
264  if (fixedpar_.nelements() == 0) {
265    out = False;
266    //throw (AipsError("No parameter mask set."));
267  } else {
268    out = fixedpar_;
269  }
270  std::vector<bool> stlout;
271  out.tovector(stlout);
272  return stlout;
273}
274
275float SDFitter::getChisquared() const {
276    return chisquared_;
277}
278
279bool SDFitter::fit() {
280    NonLinearFitLM<Float> fitter;
281    //CompiledFunction<AutoDiff<Float> > comp;
282    //Polynomial<AutoDiff<Float> > poly;
283    CompoundFunction<AutoDiff<Float> > func;
284    if (dynamic_cast<Gaussian1D<Float>* >(funcs_[0]) != 0) {
285        //computeEstimates();
286        for (uInt i=0; i<funcs_.nelements(); i++) {
287            Gaussian1D<AutoDiff<Float> > gauss;//(*funcs_[i]);
288           
289            for (uInt j=0; j<funcs_[i]->nparameters(); j++) {
290                gauss[j] = AutoDiff<Float>((*funcs_[i])[j],
291                                           gauss.nparameters(), j);
292                gauss.mask(j) = funcs_[i]->mask(j);
293            }
294           
295            func.addFunction(gauss);
296        }
297    } else if (dynamic_cast<Polynomial<Float>* >(funcs_[0]) != 0) {
298        Polynomial<AutoDiff<Float> > poly(funcs_[0]->nparameters()-1);
299        //Polynomial<AutoDiff<Float> > poly(*funcs_[0]);
300        for (uInt j=0; j<funcs_[0]->nparameters(); j++) {
301            poly[j] = AutoDiff<Float>(0, poly.nparameters(), j);
302            poly.mask(j) = funcs_[0]->mask(j);
303        }
304        func.addFunction(poly);
305    } else if (dynamic_cast<CompiledFunction<Float>* >(funcs_[0]) != 0) {
306       
307        //         CompiledFunction<AutoDiff<Float> > comp;
308        //         for (uInt j=0; j<funcs_[0]->nparameters(); j++) {
309//             comp[j] = AutoDiff<Float>(0, comp.nparameters(), j);
310//             comp.mask(j) = funcs_[0]->mask(j);
311//         }
312//         func.addFunction(comp);
313       
314        cout << "NYI." << endl;
315    } else {
316        throw(AipsError("Fitter not set up correctly."));
317    }
318    fitter.setFunction(func);
319    fitter.setMaxIter(50+funcs_.nelements()*10);
320    // Convergence criterium
321    fitter.setCriteria(0.001);
322
323    // Fit
324    Vector<Float> sigma(x_.nelements());
325    sigma = 1.0;
326
327    parameters_.resize();
328    parameters_ = fitter.fit(x_, y_, sigma, &m_);
329
330    error_.resize();
331    error_ = fitter.errors();
332
333    chisquared_ = fitter.getChi2();
334
335    residual_.resize();
336    residual_ =  y_;
337    fitter.residual(residual_,x_);
338
339    // use fitter.residual(model=True) to get the model
340    thefit_.resize(x_.nelements());
341    fitter.residual(thefit_,x_,True);
342    return true;
343}
Note: See TracBrowser for help on using the repository browser.