source: trunk/python/asapfitter.py @ 1420

Last change on this file since 1420 was 1420, checked in by Malte Marquarding, 16 years ago

typo fix

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 21.2 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log
4from asap import _n_bools
5
6class fitter:
7    """
8    The fitting class for ASAP.
9    """
10
11    def __init__(self):
12        """
13        Create a fitter object. No state is set.
14        """
15        self.fitter = _asap.fitter()
16        self.x = None
17        self.y = None
18        self.mask = None
19        self.fitfunc = None
20        self.fitfuncs = None
21        self.fitted = False
22        self.data = None
23        self.components = 0
24        self._fittedrow = 0
25        self._p = None
26        self._selection = None
27        self.uselinear = False
28
29    def set_data(self, xdat, ydat, mask=None):
30        """
31        Set the absissa and ordinate for the fit. Also set the mask
32        indicationg valid points.
33        This can be used for data vectors retrieved from a scantable.
34        For scantable fitting use 'fitter.set_scan(scan, mask)'.
35        Parameters:
36            xdat:    the abcissa values
37            ydat:    the ordinate values
38            mask:    an optional mask
39
40        """
41        self.fitted = False
42        self.x = xdat
43        self.y = ydat
44        if mask == None:
45            self.mask = _n_bools(len(xdat), True)
46        else:
47            self.mask = mask
48        return
49
50    def set_scan(self, thescan=None, mask=None):
51        """
52        Set the 'data' (a scantable) of the fitter.
53        Parameters:
54            thescan:     a scantable
55            mask:        a msk retrieved from the scantable
56        """
57        if not thescan:
58            msg = "Please give a correct scan"
59            if rcParams['verbose']:
60                print msg
61                return
62            else:
63                raise TypeError(msg)
64        self.fitted = False
65        self.data = thescan
66        self.mask = None
67        if mask is None:
68            self.mask = _n_bools(self.data.nchan(), True)
69        else:
70            self.mask = mask
71        return
72
73    def set_function(self, **kwargs):
74        """
75        Set the function to be fit.
76        Parameters:
77            poly:    use a polynomial of the order given with nonlinear least squares fit
78            lpoly:   use polynomial of the order given with linear least squares fit
79            gauss:   fit the number of gaussian specified
80        Example:
81            fitter.set_function(gauss=2) # will fit two gaussians
82            fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
83            fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial via linear method
84        """
85        #default poly order 0
86        n=0
87        if kwargs.has_key('poly'):
88            self.fitfunc = 'poly'
89            n = kwargs.get('poly')
90            self.components = [n]
91            self.uselinear = False
92        elif kwargs.has_key('lpoly'):
93            self.fitfunc = 'poly'
94            n = kwargs.get('lpoly')
95            self.components = [n]
96            self.uselinear = True
97        elif kwargs.has_key('gauss'):
98            n = kwargs.get('gauss')
99            self.fitfunc = 'gauss'
100            self.fitfuncs = [ 'gauss' for i in range(n) ]
101            self.components = [ 3 for i in range(n) ]
102            self.uselinear = False
103        else:
104            msg = "Invalid function type."
105            if rcParams['verbose']:
106                print msg
107                return
108            else:
109                raise TypeError(msg)
110
111        self.fitter.setexpression(self.fitfunc,n)
112        self.fitted = False
113        return
114
115    def fit(self, row=0, estimate=False):
116        """
117        Execute the actual fitting process. All the state has to be set.
118        Parameters:
119            row:        specify the row in the scantable
120            estimate:   auto-compute an initial parameter set (default False)
121                        This can be used to compute estimates even if fit was
122                        called before.
123        Example:
124            s = scantable('myscan.asap')
125            s.set_cursor(thepol=1)        # select second pol
126            f = fitter()
127            f.set_scan(s)
128            f.set_function(poly=0)
129            f.fit(row=0)                  # fit first row
130        """
131        if ((self.x is None or self.y is None) and self.data is None) \
132               or self.fitfunc is None:
133            msg = "Fitter not yet initialised. Please set data & fit function"
134            if rcParams['verbose']:
135                print msg
136                return
137            else:
138                raise RuntimeError(msg)
139
140        else:
141            if self.data is not None:
142                self.x = self.data._getabcissa(row)
143                self.y = self.data._getspectrum(row)
144                from asap import asaplog
145                asaplog.push("Fitting:")
146                i = row
147                out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),self.data.getbeam(i),self.data.getif(i),self.data.getpol(i), self.data.getcycle(i))
148                asaplog.push(out,False)
149        self.fitter.setdata(self.x, self.y, self.mask)
150        if self.fitfunc == 'gauss':
151            ps = self.fitter.getparameters()
152            if len(ps) == 0 or estimate:
153                self.fitter.estimate()
154        try:
155            fxdpar = list(self.fitter.getfixedparameters())
156            if len(fxdpar) and fxdpar.count(0) == 0:
157                 raise RuntimeError,"No point fitting, if all parameters are fixed."
158            if self.uselinear:
159                converged = self.fitter.lfit()
160            else:
161                converged = self.fitter.fit()
162            if not converged:
163                raise RuntimeError,"Fit didn't converge."
164        except RuntimeError, msg:
165            if rcParams['verbose']:
166                print msg
167            else:
168                raise
169        self._fittedrow = row
170        self.fitted = True
171        print_log()
172        return
173
174    def store_fit(self, filename=None):
175        """
176        Save the fit parameters.
177        Parameters:
178            filename:    if specified save as an ASCII file, if None (default)
179                         store it in the scnatable
180        """
181        if self.fitted and self.data is not None:
182            pars = list(self.fitter.getparameters())
183            fixed = list(self.fitter.getfixedparameters())
184            from asap.asapfit import asapfit
185            fit = asapfit()
186            fit.setparameters(pars)
187            fit.setfixedparameters(fixed)
188            fit.setfunctions(self.fitfuncs)
189            fit.setcomponents(self.components)
190            fit.setframeinfo(self.data._getcoordinfo())
191            if filename is not None:
192                import os
193                filename = os.path.expandvars(os.path.expanduser(filename))
194                if os.path.exists(filename):
195                    raise IOError("File '%s' exists." % filename)
196                fit.save(filename)
197            else:
198                self.data._addfit(fit,self._fittedrow)
199
200    #def set_parameters(self, params, fixed=None, component=None):
201    def set_parameters(self,*args,**kwargs):
202        """
203        Set the parameters to be fitted.
204        Parameters:
205              params:    a vector of parameters
206              fixed:     a vector of which parameters are to be held fixed
207                         (default is none)
208              component: in case of multiple gaussians, the index of the
209                         component
210        """
211        component = None
212        fixed = None
213        params = None
214
215        if len(args) and isinstance(args[0],dict):
216            kwargs = args[0]
217        if kwargs.has_key("fixed"): fixed = kwargs["fixed"]
218        if kwargs.has_key("params"): params = kwargs["params"]
219        if len(args) == 2 and isinstance(args[1], int):
220            component = args[1]
221        if self.fitfunc is None:
222            msg = "Please specify a fitting function first."
223            if rcParams['verbose']:
224                print msg
225                return
226            else:
227                raise RuntimeError(msg)
228        if self.fitfunc == "gauss" and component is not None:
229            if not self.fitted and sum(self.fitter.getparameters()) == 0:
230                pars = _n_bools(len(self.components)*3, False)
231                fxd = _n_bools(len(pars), False)
232            else:
233                pars = list(self.fitter.getparameters())
234                fxd = list(self.fitter.getfixedparameters())
235            i = 3*component
236            pars[i:i+3] = params
237            fxd[i:i+3] = fixed
238            params = pars
239            fixed = fxd
240        self.fitter.setparameters(params)
241        if fixed is not None:
242            self.fitter.setfixedparameters(fixed)
243        print_log()
244        return
245
246    def set_gauss_parameters(self, peak, centre, fwhm,
247                             peakfixed=0, centrefixed=0,
248                             fwhmfixed=0,
249                             component=0):
250        """
251        Set the Parameters of a 'Gaussian' component, set with set_function.
252        Parameters:
253            peak, centre, fwhm:  The gaussian parameters
254            peakfixed,
255            centrefixed,
256            fwhmfixed:           Optional parameters to indicate if
257                                 the paramters should be held fixed during
258                                 the fitting process. The default is to keep
259                                 all parameters flexible.
260            component:           The number of the component (Default is the
261                                 component 0)
262        """
263        if self.fitfunc != "gauss":
264            msg = "Function only operates on Gaussian components."
265            if rcParams['verbose']:
266                print msg
267                return
268            else:
269                raise ValueError(msg)
270        if 0 <= component < len(self.components):
271            d = {'params':[peak, centre, fwhm],
272                 'fixed':[peakfixed, centrefixed, fwhmfixed]}
273            self.set_parameters(d, component)
274        else:
275            msg = "Please select a valid  component."
276            if rcParams['verbose']:
277                print msg
278                return
279            else:
280                raise ValueError(msg)
281
282    def get_area(self, component=None):
283        """
284        Return the area under the fitted gaussian component.
285        Parameters:
286              component:   the gaussian component selection,
287                           default (None) is the sum of all components
288        Note:
289              This will only work for gaussian fits.
290        """
291        if not self.fitted: return
292        if self.fitfunc == "gauss":
293            pars = list(self.fitter.getparameters())
294            from math import log,pi,sqrt
295            fac = sqrt(pi/log(16.0))
296            areas = []
297            for i in range(len(self.components)):
298                j = i*3
299                cpars = pars[j:j+3]
300                areas.append(fac * cpars[0] * cpars[2])
301        else:
302            return None
303        if component is not None:
304            return areas[component]
305        else:
306            return sum(areas)
307
308    def get_errors(self, component=None):
309        """
310        Return the errors in the parameters.
311        Parameters:
312            component:    get the errors for the specified component
313                          only, default is all components
314        """
315        if not self.fitted:
316            msg = "Not yet fitted."
317            if rcParams['verbose']:
318                print msg
319                return
320            else:
321                raise RuntimeError(msg)
322        errs = list(self.fitter.geterrors())
323        cerrs = errs
324        if component is not None:
325            if self.fitfunc == "gauss":
326                i = 3*component
327                if i < len(errs):
328                    cerrs = errs[i:i+3]
329        return cerrs
330
331    def get_parameters(self, component=None, errors=False):
332        """
333        Return the fit paramters.
334        Parameters:
335             component:    get the parameters for the specified component
336                           only, default is all components
337        """
338        if not self.fitted:
339            msg = "Not yet fitted."
340            if rcParams['verbose']:
341                print msg
342                return
343            else:
344                raise RuntimeError(msg)
345        pars = list(self.fitter.getparameters())
346        fixed = list(self.fitter.getfixedparameters())
347        errs = list(self.fitter.geterrors())
348        area = []
349        if component is not None:
350            if self.fitfunc == "gauss":
351                i = 3*component
352                cpars = pars[i:i+3]
353                cfixed = fixed[i:i+3]
354                cerrs = errs[i:i+3]
355                a = self.get_area(component)
356                area = [a for i in range(3)]
357            else:
358                cpars = pars
359                cfixed = fixed
360                cerrs = errs
361        else:
362            cpars = pars
363            cfixed = fixed
364            cerrs = errs
365            if self.fitfunc == "gauss":
366                for c in range(len(self.components)):
367                  a = self.get_area(c)
368                  area += [a for i in range(3)]
369        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
370        if rcParams['verbose']:
371            print fpars
372        return {'params':cpars, 'fixed':cfixed, 'formatted': fpars,
373                'errors':cerrs}
374
375    def _format_pars(self, pars, fixed, errors, area):
376        out = ''
377        if self.fitfunc == 'poly':
378            c = 0
379            for i in range(len(pars)):
380                fix = ""
381                if len(fixed) and fixed[i]: fix = "(fixed)"
382                if errors :
383                    out += '  p%d%s= %3.6f (%1.6f),' % (c,fix,pars[i], errors[i])
384                else:
385                    out += '  p%d%s= %3.6f,' % (c,fix,pars[i])
386                c+=1
387            out = out[:-1]  # remove trailing ','
388        elif self.fitfunc == 'gauss':
389            i = 0
390            c = 0
391            aunit = ''
392            ounit = ''
393            if self.data:
394                aunit = self.data.get_unit()
395                ounit = self.data.get_fluxunit()
396            while i < len(pars):
397                if len(area):
398                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n      area = %3.3f %s %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit, area[i],ounit,aunit)
399                else:
400                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit,ounit,aunit)
401                c+=1
402                i+=3
403        return out
404
405    def get_estimate(self):
406        """
407        Return the parameter estimates (for non-linear functions).
408        """
409        pars = self.fitter.getestimate()
410        fixed = self.fitter.getfixedparameters()
411        if rcParams['verbose']:
412            print self._format_pars(pars,fixed,None)
413        return pars
414
415    def get_residual(self):
416        """
417        Return the residual of the fit.
418        """
419        if not self.fitted:
420            msg = "Not yet fitted."
421            if rcParams['verbose']:
422                print msg
423                return
424            else:
425                raise RuntimeError(msg)
426        return self.fitter.getresidual()
427
428    def get_chi2(self):
429        """
430        Return chi^2.
431        """
432        if not self.fitted:
433            msg = "Not yet fitted."
434            if rcParams['verbose']:
435                print msg
436                return
437            else:
438                raise RuntimeError(msg)
439        ch2 = self.fitter.getchi2()
440        if rcParams['verbose']:
441            print 'Chi^2 = %3.3f' % (ch2)
442        return ch2
443
444    def get_fit(self):
445        """
446        Return the fitted ordinate values.
447        """
448        if not self.fitted:
449            msg = "Not yet fitted."
450            if rcParams['verbose']:
451                print msg
452                return
453            else:
454                raise RuntimeError(msg)
455        return self.fitter.getfit()
456
457    def commit(self):
458        """
459        Return a new scan where the fits have been commited (subtracted)
460        """
461        if not self.fitted:
462            msg = "Not yet fitted."
463            if rcParams['verbose']:
464                print msg
465                return
466            else:
467                raise RuntimeError(msg)
468        from asap import scantable
469        if not isinstance(self.data, scantable):
470            msg = "Not a scantable"
471            if rcParams['verbose']:
472                print msg
473                return
474            else:
475                raise TypeError(msg)
476        scan = self.data.copy()
477        scan._setspectrum(self.fitter.getresidual())
478        print_log()
479        return scan
480
481    def plot(self, residual=False, components=None, plotparms=False, filename=None):
482        """
483        Plot the last fit.
484        Parameters:
485            residual:    an optional parameter indicating if the residual
486                         should be plotted (default 'False')
487            components:  a list of components to plot, e.g [0,1],
488                         -1 plots the total fit. Default is to only
489                         plot the total fit.
490            plotparms:   Inidicates if the parameter values should be present
491                         on the plot
492        """
493        if not self.fitted:
494            return
495        if not self._p or self._p.is_dead:
496            if rcParams['plotter.gui']:
497                from asap.asaplotgui import asaplotgui as asaplot
498            else:
499                from asap.asaplot import asaplot
500            self._p = asaplot()
501        self._p.hold()
502        self._p.clear()
503        self._p.set_panels()
504        self._p.palette(0)
505        tlab = 'Spectrum'
506        xlab = 'Abcissa'
507        ylab = 'Ordinate'
508        from matplotlib.numerix import ma,logical_not,logical_and,array
509        m = self.mask
510        if self.data:
511            tlab = self.data._getsourcename(self._fittedrow)
512            xlab = self.data._getabcissalabel(self._fittedrow)
513            m =  logical_and(self.mask,
514                             array(self.data._getmask(self._fittedrow),
515                                   copy=False))
516                             
517            ylab = self.data._get_ordinate_label()
518
519        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
520        self._p.palette(0,colours)
521        self._p.set_line(label='Spectrum')
522        y = ma.masked_array(self.y,mask=logical_not(m))
523        self._p.plot(self.x, y)
524        if residual:
525            self._p.palette(1)
526            self._p.set_line(label='Residual')
527            y = ma.masked_array(self.get_residual(),
528                                  mask=logical_not(m))
529            self._p.plot(self.x, y)
530        self._p.palette(2)
531        if components is not None:
532            cs = components
533            if isinstance(components,int): cs = [components]
534            if plotparms:
535                self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
536            n = len(self.components)
537            self._p.palette(3)
538            for c in cs:
539                if 0 <= c < n:
540                    lab = self.fitfuncs[c]+str(c)
541                    self._p.set_line(label=lab)
542                    y = ma.masked_array(self.fitter.evaluate(c),
543                                          mask=logical_not(m))
544
545                    self._p.plot(self.x, y)
546                elif c == -1:
547                    self._p.palette(2)
548                    self._p.set_line(label="Total Fit")
549                    y = ma.masked_array(self.fitter.getfit(),
550                                          mask=logical_not(m))
551                    self._p.plot(self.x, y)
552        else:
553            self._p.palette(2)
554            self._p.set_line(label='Fit')
555            y = ma.masked_array(self.fitter.getfit(),
556                                  mask=logical_not(m))
557            self._p.plot(self.x, y)
558        xlim=[min(self.x),max(self.x)]
559        self._p.axes.set_xlim(xlim)
560        self._p.set_axes('xlabel',xlab)
561        self._p.set_axes('ylabel',ylab)
562        self._p.set_axes('title',tlab)
563        self._p.release()
564        if (not rcParams['plotter.gui']):
565            self._p.save(filename)
566        print_log()
567
568    def auto_fit(self, insitu=None, plot=False):
569        """
570        Return a scan where the function is applied to all rows for
571        all Beams/IFs/Pols.
572
573        """
574        from asap import scantable
575        if not isinstance(self.data, scantable) :
576            msg = "Data is not a scantable"
577            if rcParams['verbose']:
578                print msg
579                return
580            else:
581                raise TypeError(msg)
582        if insitu is None: insitu = rcParams['insitu']
583        if not insitu:
584            scan = self.data.copy()
585        else:
586            scan = self.data
587        rows = xrange(scan.nrow())
588        from asap import asaplog
589        asaplog.push("Fitting:")
590        for r in rows:
591            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),scan.getbeam(r),scan.getif(r),scan.getpol(r), scan.getcycle(r))
592            asaplog.push(out, False)
593            self.x = scan._getabcissa(r)
594            self.y = scan._getspectrum(r)
595            self.data = None
596            self.fit()
597            x = self.get_parameters()
598            if plot:
599                self.plot(residual=True)
600                x = raw_input("Accept fit ([y]/n): ")
601                if x.upper() == 'N':
602                    continue
603            scan._setspectrum(self.fitter.getresidual(), r)
604        if plot:
605            self._p.unmap()
606            self._p = None
607        print_log()
608        return scan
609
Note: See TracBrowser for help on using the repository browser.