Changeset 2666 for trunk


Ignore:
Timestamp:
10/15/12 15:52:38 (12 years ago)
Author:
Malte Marquarding
Message:

Ticket #173: added setting of constraints on the fitter.

Location:
trunk
Files:
8 edited

Legend:

Unmodified
Added
Removed
  • trunk/python/asapfitter.py

    r2541 r2666  
    33from asap.logging import asaplog, asaplog_post_dec
    44from asap.utils import _n_bools, mask_and
    5 
     5from numpy import ndarray
    66
    77class fitter:
     
    2626        self._selection = None
    2727        self.uselinear = False
     28        self._constraints = []
    2829
    2930    def set_data(self, xdat, ydat, mask=None):
     
    7374        Set the function to be fit.
    7475        Parameters:
    75             poly:     use a polynomial of the order given with nonlinear least squares fit
    76             lpoly:    use polynomial of the order given with linear least squares fit
     76            poly:     use a polynomial of the order given with nonlinear
     77                      least squares fit
     78            lpoly:    use polynomial of the order given with linear least
     79                      squares fit
    7780            gauss:    fit the number of gaussian specified
    7881            lorentz:  fit the number of lorentzian specified
    7982            sinusoid: fit the number of sinusoid specified
    8083        Example:
    81             fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
    82             fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial via linear method
     84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
     85                                         # via nonlinear method
     86            fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial
     87                                          # via linear method
    8388            fitter.set_function(gauss=2) # will fit two gaussians
    8489            fitter.set_function(lorentz=2) # will fit two lorentzians
     
    117122            self.components = [ 3 for i in range(n) ]
    118123            self.uselinear = False
     124        elif kwargs.has_key('expression'):
     125            self.uselinear = False
     126            raise RuntimeError("Not yet implemented")
    119127        else:
    120128            msg = "Invalid function type."
     
    122130
    123131        self.fitter.setexpression(self.fitfunc,n)
     132        self._constraints = []
    124133        self.fitted = False
    125134        return
     
    147156            raise RuntimeError(msg)
    148157
    149         else:
    150             if self.data is not None:
    151                 self.x = self.data._getabcissa(row)
    152                 self.y = self.data._getspectrum(row)
    153                 #self.mask = mask_and(self.mask, self.data._getmask(row))
    154                 if len(self.x) == len(self.mask):
    155                     self.mask = mask_and(self.mask, self.data._getmask(row))
    156                 else:
    157                     asaplog.push('lengths of data and mask are not the same. preset mask will be ignored')
    158                     asaplog.post('WARN','asapfit.fit')
    159                     self.mask=self.data._getmask(row)
    160                 asaplog.push("Fitting:")
    161                 i = row
    162                 out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),
    163                                                                       self.data.getbeam(i),
    164                                                                       self.data.getif(i),
    165                                                                       self.data.getpol(i),
    166                                                                       self.data.getcycle(i))
    167                 asaplog.push(out,False)
     158        if self.data is not None:
     159            self.x = self.data._getabcissa(row)
     160            self.y = self.data._getspectrum(row)
     161            #self.mask = mask_and(self.mask, self.data._getmask(row))
     162            if len(self.x) == len(self.mask):
     163                self.mask = mask_and(self.mask, self.data._getmask(row))
     164            else:
     165                asaplog.push('lengths of data and mask are not the same. '
     166                             'preset mask will be ignored')
     167                asaplog.post('WARN','asapfit.fit')
     168                self.mask=self.data._getmask(row)
     169            asaplog.push("Fitting:")
     170            i = row
     171            out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
     172                self.data.getscan(i),
     173                self.data.getbeam(i),
     174                self.data.getif(i),
     175                self.data.getpol(i),
     176                self.data.getcycle(i))
     177           
     178            asaplog.push(out, False)
     179
    168180        self.fitter.setdata(self.x, self.y, self.mask)
    169181        if self.fitfunc == 'gauss' or self.fitfunc == 'lorentz':
     
    174186        if len(fxdpar) and fxdpar.count(0) == 0:
    175187             raise RuntimeError,"No point fitting, if all parameters are fixed."
     188        if self._constraints:
     189            for c in self._constraints:
     190                self.fitter.addconstraint(c[0]+[c[-1]])
    176191        if self.uselinear:
    177192            converged = self.fitter.lfit()
     
    234249            msg = "Please specify a fitting function first."
    235250            raise RuntimeError(msg)
    236         if (self.fitfunc == "gauss" or self.fitfunc == "lorentz" or self.fitfunc == "sinusoid") and component is not None:
     251        if (self.fitfunc == "gauss" or self.fitfunc == "lorentz"
     252            or self.fitfunc == "sinusoid") and component is not None:
    237253            if not self.fitted and sum(self.fitter.getparameters()) == 0:
    238254                pars = _n_bools(len(self.components)*3, False)
     
    338354            raise ValueError(msg)
    339355
     356
     357    def add_constraint(self, xpar, y):
     358        """Add parameter constraints to the fit. This is done by setting up
     359        linear equations for the related parameters.
     360
     361        For example a two component gaussian fit where the amplitudes are
     362        constraint by amp1 = 2*amp2
     363        needs a constraint   
     364
     365            add_constraint([1, 0, 0, -2, 0, 0, 0], 0)
     366
     367        a velocity difference of v2-v1=17
     368
     369            add_constraint([0.,-1.,0.,0.,1.,0.], 17.)
     370
     371        """
     372        self._constraints.append((xpar, y))
     373       
     374
    340375    def get_area(self, component=None):
    341376        """
     
    381416        cerrs = errs
    382417        if component is not None:
    383             if self.fitfunc == "gauss" or self.fitfunc == "lorentz" or self.fitfunc == "sinusoid":
     418            if self.fitfunc == "gauss" or self.fitfunc == "lorentz" \
     419                    or self.fitfunc == "sinusoid":
    384420                i = 3*component
    385421                if i < len(errs):
     
    463499                out += "%s%s = %3.3f %s, " % (pnam[1], fix1, pars[i+1], aunit)
    464500                out += "%s%s = %3.3f %s\n" % (pnam[2], fix2, pars[i+2], aunit)
    465                 if len(area): out += "      area = %3.3f %s %s\n" % (area[i], ounit, aunit)
     501                if len(area): out += "      area = %3.3f %s %s\n" % (area[i],
     502                                                                     ounit,
     503                                                                     aunit)
    466504                c+=1
    467505                i+=3
     
    571609            ylab = self.data._get_ordinate_label()
    572610
    573         colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
     611        colours = ["#777777","#dddddd","red","orange","purple","green",
     612                   "magenta", "cyan"]
    574613        nomask=True
    575614        for i in range(len(m)):
     
    604643            if isinstance(components,int): cs = [components]
    605644            if plotparms:
    606                 self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
     645                self._p.text(0.15,0.15,
     646                             str(self.get_parameters()['formatted']),size=8)
    607647            n = len(self.components)
    608648            self._p.palette(3)
     
    656696        asaplog.push("Fitting:")
    657697        for r in rows:
    658             out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),
    659                                                                    scan.getbeam(r),
    660                                                                    scan.getif(r),
    661                                                                    scan.getpol(r),
    662                                                                    scan.getcycle(r))
     698            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (
     699                scan.getscan(r),
     700                scan.getbeam(r),
     701                scan.getif(r),
     702                scan.getpol(r),
     703                scan.getcycle(r)
     704                )
    663705            asaplog.push(out, False)
    664706            self.x = scan._getabcissa(r)
     
    668710                self.mask = mask_and(self.mask, self.data._getmask(row))
    669711            else:
    670                 asaplog.push('lengths of data and mask are not the same. preset mask will be ignored')
     712                asaplog.push('lengths of data and mask are not the same. '
     713                             'preset mask will be ignored')
    671714                asaplog.post('WARN','asapfit.fit')
    672715                self.mask=self.data._getmask(row)
  • trunk/src/STFitter.cpp

    r2580 r2666  
    8080  y_.resize();
    8181  m_.resize();
     82  constraints_.clear();
    8283}
    8384
     
    294295}
    295296
     297void Fitter::addConstraint(const std::vector<float>& constraint)
     298{
     299  if (funcs_.nelements() == 0)
     300    throw (AipsError("Function not yet set."));
     301  constraints_.push_back(constraint);
     302 
     303}
     304
     305void Fitter::applyConstraints(GenericL2Fit<Float>& fitter)
     306{
     307  std::vector<std::vector<float> >::const_iterator it;
     308  for (it = constraints_.begin(); it != constraints_.end(); ++it) {
     309    Vector<Float> tmp(*it);
     310    fitter.addConstraint(tmp(Slice(0,tmp.nelements()-1)),
     311                         tmp(tmp.nelements()-1));
     312  }
     313}
     314
    296315bool Fitter::setFixedParameters(std::vector<bool> fixed)
    297316{
     
    377396  // Convergence criterium
    378397  fitter.setCriteria(0.001);
     398  applyConstraints(fitter);
    379399
    380400  // Fit
     
    397417  chisquared_ = fitter.getChi2();
    398418
    399 //   residual_.resize();
    400 //   residual_ =  y_;
    401 //   fitter.residual(residual_,x_);
    402419  // use fitter.residual(model=True) to get the model
    403420  thefit_.resize(x_.nelements());
    404421  fitter.residual(thefit_,x_,True);
    405   // residual = data - model
    406422  residual_.resize(x_.nelements());
    407423  residual_ = y_ - thefit_ ;
     
    419435
    420436  fitter.setFunction(func);
    421   //fitter.setMaxIter(50+n*10);
    422   // Convergence criterium
    423   //fitter.setCriteria(0.001);
    424 
    425   // Fit
    426 //   Vector<Float> sigma(x_.nelements());
    427 //   sigma = 1.0;
     437  applyConstraints(fitter);
    428438
    429439  parameters_.resize();
    430 //   parameters_ = fitter.fit(x_, y_, sigma, &m_);
    431440  parameters_ = fitter.fit(x_, y_, &m_);
    432441  std::vector<float> ps;
     
    439448  chisquared_ = fitter.getChi2();
    440449
    441 //   residual_.resize();
    442 //   residual_ =  y_;
    443 //   fitter.residual(residual_,x_);
    444   // use fitter.residual(model=True) to get the model
    445450  thefit_.resize(x_.nelements());
    446451  fitter.residual(thefit_,x_,True);
    447   // residual = data - model
    448452  residual_.resize(x_.nelements());
    449453  residual_ = y_ - thefit_ ;
  • trunk/src/STFitter.h

    r1932 r2666  
    3939#include <scimath/Functionals/Function.h>
    4040#include <scimath/Functionals/CompoundFunction.h>
     41#include <scimath/Fitting/GenericL2Fit.h>
     42
    4143
    4244#include "STFitEntry.h"
     45
    4346
    4447namespace asap {
     
    5558  bool setParameters(std::vector<float> params);
    5659  bool setFixedParameters(std::vector<bool> fixed);
     60  void addConstraint(const std::vector<float>& constraint);
    5761
    5862  std::vector<float> getResidual() const;
     
    7680private:
    7781  void clear();
     82  void applyConstraints(casa::GenericL2Fit<casa::Float>& fitter);
    7883  casa::Vector<casa::Float> x_;
    7984  casa::Vector<casa::Float> y_;
     
    8792  casa::Vector<casa::Float> parameters_;
    8893  casa::Vector<casa::Bool> fixedpar_;
     94  std::vector<std::vector<float> > constraints_;
    8995
    9096  casa::Vector<casa::Float> error_;
  • trunk/src/STLineFinder.h

    r2580 r2666  
    4848#include "ScantableWrapper.h"
    4949#include "Scantable.h"
    50 #include "STFitter.h"
    5150
    5251namespace asap {
  • trunk/src/Scantable.cpp

    r2658 r2666  
    7171#include "STPolStokes.h"
    7272#include "STUpgrade.h"
     73#include "STFitter.h"
    7374#include "Scantable.h"
    7475
  • trunk/src/Scantable.h

    r2658 r2666  
    4343#include "STFit.h"
    4444#include "STFitEntry.h"
    45 #include "STFitter.h"
     45//#include "STFitter.h"
    4646#include "STFocus.h"
    4747#include "STFrequencies.h"
     
    5555
    5656namespace asap {
     57
     58class Fitter;
    5759
    5860/**
  • trunk/src/python_Fitter.cpp

    r1391 r2666  
    4949        .def("getparameters", &Fitter::getParameters)
    5050        .def("setparameters", &Fitter::setParameters)
     51        .def("addconstraint", &Fitter::addConstraint)
    5152        .def("getestimate", &Fitter::getEstimate)
    5253        .def("estimate", &Fitter::computeEstimate)
  • trunk/src/python_asap.cpp

    r2658 r2666  
    124124  casa::pyrap::register_convert_casa_valueholder();
    125125  casa::pyrap::register_convert_casa_record();
     126
    126127#endif
    127128}
Note: See TracChangeset for help on using the changeset viewer.