source: trunk/python/asapfitter.py @ 1088

Last change on this file since 1088 was 1088, checked in by mar637, 18 years ago

use MA instead of spectrum and mask for plotting. THIS isn't tested yet. printing out errors for poly coeffs. need to do gauss

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 20.1 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log
4
5class fitter:
6    """
7    The fitting class for ASAP.
8    """
9
10    def __init__(self):
11        """
12        Create a fitter object. No state is set.
13        """
14        self.fitter = _asap.fitter()
15        self.x = None
16        self.y = None
17        self.mask = None
18        self.fitfunc = None
19        self.fitfuncs = None
20        self.fitted = False
21        self.data = None
22        self.components = 0
23        self._fittedrow = 0
24        self._p = None
25        self._selection = None
26
27    def set_data(self, xdat, ydat, mask=None):
28        """
29        Set the absissa and ordinate for the fit. Also set the mask
30        indicationg valid points.
31        This can be used for data vectors retrieved from a scantable.
32        For scantable fitting use 'fitter.set_scan(scan, mask)'.
33        Parameters:
34            xdat:    the abcissa values
35            ydat:    the ordinate values
36            mask:    an optional mask
37
38        """
39        self.fitted = False
40        self.x = xdat
41        self.y = ydat
42        if mask == None:
43            from numarray import ones
44            self.mask = ones(len(xdat))
45        else:
46            self.mask = mask
47        return
48
49    def set_scan(self, thescan=None, mask=None):
50        """
51        Set the 'data' (a scantable) of the fitter.
52        Parameters:
53            thescan:     a scantable
54            mask:        a msk retireved from the scantable
55        """
56        if not thescan:
57            msg = "Please give a correct scan"
58            if rcParams['verbose']:
59                print msg
60                return
61            else:
62                raise TypeError(msg)
63        self.fitted = False
64        self.data = thescan
65        self.mask = None
66        if mask is None:
67            from numarray import ones
68            self.mask = ones(self.data.nchan())
69        else:
70            self.mask = mask
71        return
72
73    def set_function(self, **kwargs):
74        """
75        Set the function to be fit.
76        Parameters:
77            poly:    use a polynomial of the order given
78            gauss:   fit the number of gaussian specified
79        Example:
80            fitter.set_function(gauss=2) # will fit two gaussians
81            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
82        """
83        #default poly order 0
84        n=0
85        if kwargs.has_key('poly'):
86            self.fitfunc = 'poly'
87            n = kwargs.get('poly')
88            self.components = [n]
89        elif kwargs.has_key('gauss'):
90            n = kwargs.get('gauss')
91            self.fitfunc = 'gauss'
92            self.fitfuncs = [ 'gauss' for i in range(n) ]
93            self.components = [ 3 for i in range(n) ]
94        else:
95            msg = "Invalid function type."
96            if rcParams['verbose']:
97                print msg
98                return
99            else:
100                raise TypeError(msg)
101
102        self.fitter.setexpression(self.fitfunc,n)
103        return
104
105    def fit(self, row=0, estimate=False):
106        """
107        Execute the actual fitting process. All the state has to be set.
108        Parameters:
109            row:        specify the row in the scantable
110            estimate:   auto-compute an initial parameter set (default False)
111                        This can be used to compute estimates even if fit was
112                        called before.
113        Example:
114            s = scantable('myscan.asap')
115            s.set_cursor(thepol=1)        # select second pol
116            f = fitter()
117            f.set_scan(s)
118            f.set_function(poly=0)
119            f.fit(row=0)                  # fit first row
120        """
121        if ((self.x is None or self.y is None) and self.data is None) \
122               or self.fitfunc is None:
123            msg = "Fitter not yet initialised. Please set data & fit function"
124            if rcParams['verbose']:
125                print msg
126                return
127            else:
128                raise RuntimeError(msg)
129
130        else:
131            if self.data is not None:
132                self.x = self.data._getabcissa(row)
133                self.y = self.data._getspectrum(row)
134                from asap import asaplog
135                asaplog.push("Fitting:")
136                i = row
137                out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),self.data.getbeam(i),self.data.getif(i),self.data.getpol(i), self.data.getcycle(i))
138                asaplog.push(out,False)
139        self.fitter.setdata(self.x, self.y, self.mask)
140        if self.fitfunc == 'gauss':
141            ps = self.fitter.getparameters()
142            if len(ps) == 0 or estimate:
143                self.fitter.estimate()
144        try:
145            converged = self.fitter.fit()
146            if not converged:
147                raise RuntimeError,"Fit didn't converge."
148        except RuntimeError, msg:
149            if rcParams['verbose']:
150                print msg
151            else:
152                raise
153        self._fittedrow = row
154        self.fitted = True
155        print_log()
156        return
157
158    def store_fit(self):
159        """
160        Store the fit parameters in the scantable.
161        """
162        if self.fitted and self.data is not None:
163            pars = list(self.fitter.getparameters())
164            fixed = list(self.fitter.getfixedparameters())
165            from asap.asapfit import asapfit
166            fit = asapfit()
167            fit.setparameters(pars)
168            fit.setfixedparameters(fixed)
169            fit.setfunctions(self.fitfuncs)
170            fit.setcomponents(self.components)
171            fit.setframeinfo(self.data._getcoordinfo())
172            self.data._addfit(fit,self._fittedrow)
173
174    #def set_parameters(self, params, fixed=None, component=None):
175    def set_parameters(self,*args,**kwargs):
176        """
177        Set the parameters to be fitted.
178        Parameters:
179              params:    a vector of parameters
180              fixed:     a vector of which parameters are to be held fixed
181                         (default is none)
182              component: in case of multiple gaussians, the index of the
183                         component
184        """
185        component = None
186        fixed = None
187        params = None
188
189        if len(args) and isinstance(args[0],dict):
190            kwargs = args[0]
191        if kwargs.has_key("fixed"): fixed = kwargs["fixed"]
192        if kwargs.has_key("params"): params = kwargs["params"]
193        if len(args) == 2 and isinstance(args[1], int):
194            component = args[1]
195        if self.fitfunc is None:
196            msg = "Please specify a fitting function first."
197            if rcParams['verbose']:
198                print msg
199                return
200            else:
201                raise RuntimeError(msg)
202        if self.fitfunc == "gauss" and component is not None:
203            if not self.fitted and sum(self.fitter.getparameters()) == 0:
204                from numarray import zeros
205                pars = list(zeros(len(self.components)*3))
206                fxd = list(zeros(len(pars)))
207            else:
208                pars = list(self.fitter.getparameters())
209                fxd = list(self.fitter.getfixedparameters())
210            i = 3*component
211            pars[i:i+3] = params
212            fxd[i:i+3] = fixed
213            params = pars
214            fixed = fxd
215        self.fitter.setparameters(params)
216        if fixed is not None:
217            self.fitter.setfixedparameters(fixed)
218        print_log()
219        return
220
221    def set_gauss_parameters(self, peak, centre, fhwm,
222                             peakfixed=0, centerfixed=0,
223                             fhwmfixed=0,
224                             component=0):
225        """
226        Set the Parameters of a 'Gaussian' component, set with set_function.
227        Parameters:
228            peak, centre, fhwm:  The gaussian parameters
229            peakfixed,
230            centerfixed,
231            fhwmfixed:           Optional parameters to indicate if
232                                 the paramters should be held fixed during
233                                 the fitting process. The default is to keep
234                                 all parameters flexible.
235            component:           The number of the component (Default is the
236                                 component 0)
237        """
238        if self.fitfunc != "gauss":
239            msg = "Function only operates on Gaussian components."
240            if rcParams['verbose']:
241                print msg
242                return
243            else:
244                raise ValueError(msg)
245        if 0 <= component < len(self.components):
246            d = {'params':[peak, centre, fhwm],
247                 'fixed':[peakfixed, centerfixed, fhwmfixed]}
248            self.set_parameters(d, component)
249        else:
250            msg = "Please select a valid  component."
251            if rcParams['verbose']:
252                print msg
253                return
254            else:
255                raise ValueError(msg)
256
257    def get_area(self, component=None):
258        """
259        Return the area under the fitted gaussian component.
260        Parameters:
261              component:   the gaussian component selection,
262                           default (None) is the sum of all components
263        Note:
264              This will only work for gaussian fits.
265        """
266        if not self.fitted: return
267        if self.fitfunc == "gauss":
268            pars = list(self.fitter.getparameters())
269            from math import log,pi,sqrt
270            fac = sqrt(pi/log(16.0))
271            areas = []
272            for i in range(len(self.components)):
273                j = i*3
274                cpars = pars[j:j+3]
275                areas.append(fac * cpars[0] * cpars[2])
276        else:
277            return None
278        if component is not None:
279            return areas[component]
280        else:
281            return sum(areas)
282
283    def get_errors(self, component=None):
284        """
285        Return the errors in the parameters.
286        Parameters:
287            component:    get the errors for the specified component
288                          only, default is all components
289        """
290        if not self.fitted:
291            msg = "Not yet fitted."
292            if rcParams['verbose']:
293                print msg
294                return
295            else:
296                raise RuntimeError(msg)
297        errs = list(self.fitter.geterrors())
298        cerrs = errs
299        if component is not None:
300            if self.fitfunc == "gauss":
301                i = 3*component
302                if i < len(errs):
303                    cerrs = errs[i:i+3]
304        return cerrs
305
306    def get_parameters(self, component=None, errors=False):
307        """
308        Return the fit paramters.
309        Parameters:
310             component:    get the parameters for the specified component
311                           only, default is all components
312        """
313        if not self.fitted:
314            msg = "Not yet fitted."
315            if rcParams['verbose']:
316                print msg
317                return
318            else:
319                raise RuntimeError(msg)
320        pars = list(self.fitter.getparameters())
321        fixed = list(self.fitter.getfixedparameters())
322        errs = list(self.fitter.geterrors())
323        area = []
324        if component is not None:
325            if self.fitfunc == "gauss":
326                i = 3*component
327                cpars = pars[i:i+3]
328                cfixed = fixed[i:i+3]
329                cerrs = errs[i:i+3]
330                a = self.get_area(component)
331                area = [a for i in range(3)]
332            else:
333                cpars = pars
334                cfixed = fixed
335                cerrs = errs
336        else:
337            cpars = pars
338            cfixed = fixed
339            cerrs = errs
340            if self.fitfunc == "gauss":
341                for c in range(len(self.components)):
342                  a = self.get_area(c)
343                  area += [a for i in range(3)]
344        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
345        if rcParams['verbose']:
346            print fpars
347        return {'params':cpars, 'fixed':cfixed, 'formatted': fpars,
348                'errors':cerrs}
349
350    def _format_pars(self, pars, fixed, errors, area):
351        out = ''
352        if self.fitfunc == 'poly':
353            c = 0
354            for i in range(len(pars)):
355                fix = ""
356                if fixed[i]: fix = "(fixed)"
357                if errors :
358                    out += '  p%d%s= %3.6f (%1.6f),' % (c,fix,pars[i], errors[i])
359                else:
360                    out += '  p%d%s= %3.6f,' % (c,fix,pars[i])
361                c+=1
362            out = out[:-1]  # remove trailing ','
363        elif self.fitfunc == 'gauss':
364            i = 0
365            c = 0
366            aunit = ''
367            ounit = ''
368            if self.data:
369                aunit = self.data.get_unit()
370                ounit = self.data.get_fluxunit()
371            while i < len(pars):
372                if len(area):
373                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n      area = %3.3f %s %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit, area[i],ounit,aunit)
374                else:
375                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit,ounit,aunit)
376                c+=1
377                i+=3
378        return out
379
380    def get_estimate(self):
381        """
382        Return the parameter estimates (for non-linear functions).
383        """
384        pars = self.fitter.getestimate()
385        fixed = self.fitter.getfixedparameters()
386        if rcParams['verbose']:
387            print self._format_pars(pars,fixed,None)
388        return pars
389
390    def get_residual(self):
391        """
392        Return the residual of the fit.
393        """
394        if not self.fitted:
395            msg = "Not yet fitted."
396            if rcParams['verbose']:
397                print msg
398                return
399            else:
400                raise RuntimeError(msg)
401        return self.fitter.getresidual()
402
403    def get_chi2(self):
404        """
405        Return chi^2.
406        """
407        if not self.fitted:
408            msg = "Not yet fitted."
409            if rcParams['verbose']:
410                print msg
411                return
412            else:
413                raise RuntimeError(msg)
414        ch2 = self.fitter.getchi2()
415        if rcParams['verbose']:
416            print 'Chi^2 = %3.3f' % (ch2)
417        return ch2
418
419    def get_fit(self):
420        """
421        Return the fitted ordinate values.
422        """
423        if not self.fitted:
424            msg = "Not yet fitted."
425            if rcParams['verbose']:
426                print msg
427                return
428            else:
429                raise RuntimeError(msg)
430        return self.fitter.getfit()
431
432    def commit(self):
433        """
434        Return a new scan where the fits have been commited (subtracted)
435        """
436        if not self.fitted:
437            print "Not yet fitted."
438            msg = "Not yet fitted."
439            if rcParams['verbose']:
440                print msg
441                return
442            else:
443                raise RuntimeError(msg)
444        from asap import scantable
445        if not isinstance(self.data, scantable):
446            msg = "Not a scantable"
447            if rcParams['verbose']:
448                print msg
449                return
450            else:
451                raise TypeError(msg)
452        scan = self.data.copy()
453        scan._setspectrum(self.fitter.getresidual())
454        print_log()
455
456    def plot(self, residual=False, components=None, plotparms=False, filename=None):
457        """
458        Plot the last fit.
459        Parameters:
460            residual:    an optional parameter indicating if the residual
461                         should be plotted (default 'False')
462            components:  a list of components to plot, e.g [0,1],
463                         -1 plots the total fit. Default is to only
464                         plot the total fit.
465            plotparms:   Inidicates if the parameter values should be present
466                         on the plot
467        """
468        if not self.fitted:
469            return
470        if not self._p or self._p.is_dead:
471            if rcParams['plotter.gui']:
472                from asap.asaplotgui import asaplotgui as asaplot
473            else:
474                from asap.asaplot import asaplot
475            self._p = asaplot()
476        self._p.hold()
477        self._p.clear()
478        self._p.set_panels()
479        self._p.palette(0)
480        tlab = 'Spectrum'
481        xlab = 'Abcissa'
482        ylab = 'Ordinate'
483        m = None
484        if self.data:
485            tlab = self.data._getsourcename(self._fittedrow)
486            xlab = self.data._getabcissalabel(self._fittedrow)
487            m = self.data._getmask(self._fittedrow)
488            ylab = self.data._get_ordinate_label()
489
490        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
491        self._p.palette(0,colours)
492        self._p.set_line(label='Spectrum')
493        from matplotlib.numerix import ma,logical_not,array
494        y = ma.MA.MaskedArray(self.y,mask=logical_not(array(m,copy=0)),copy=0)
495        self._p.plot(self.x, y)
496        if residual:
497            self._p.palette(1)
498            self._p.set_line(label='Residual')
499            y = ma.MA.MaskedArray(self.get_residual(),
500                                  mask=logical_not(array(m,copy=0)),copy=0)
501            self._p.plot(self.x, y)
502        self._p.palette(2)
503        if components is not None:
504            cs = components
505            if isinstance(components,int): cs = [components]
506            if plotparms:
507                self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
508            n = len(self.components)
509            self._p.palette(3)
510            for c in cs:
511                if 0 <= c < n:
512                    lab = self.fitfuncs[c]+str(c)
513                    self._p.set_line(label=lab)
514                    y = ma.MA.MaskedArray(self.fitter.evaluate(c),
515                                          mask=logical_not(array(m,copy=0)),
516                                          copy=0)
517
518                    self._p.plot(self.x, y)
519                elif c == -1:
520                    self._p.palette(2)
521                    self._p.set_line(label="Total Fit")
522                    y = ma.MA.MaskedArray(self.fitter.get_fit(),
523                                          mask=logical_not(array(m,copy=0)),
524                                          copy=0)                   
525                    self._p.plot(self.x, y)
526        else:
527            self._p.palette(2)
528            self._p.set_line(label='Fit')
529            y = ma.MA.MaskedArray(self.fitter.get_fit(),
530                                  mask=logical_not(array(m,copy=0)),
531                                  copy=0)                   
532            self._p.plot(self.x, y)
533        xlim=[min(self.x),max(self.x)]
534        self._p.axes.set_xlim(xlim)
535        self._p.set_axes('xlabel',xlab)
536        self._p.set_axes('ylabel',ylab)
537        self._p.set_axes('title',tlab)
538        self._p.release()
539        if (not rcParams['plotter.gui']):
540            self._p.save(filename)
541        print_log()
542
543    def auto_fit(self, insitu=None, plot=False):
544        """
545        Return a scan where the function is applied to all rows for
546        all Beams/IFs/Pols.
547
548        """
549        from asap import scantable
550        if not isinstance(self.data, scantable) :
551            msg = "Data is not a scantable"
552            if rcParams['verbose']:
553                print msg
554                return
555            else:
556                raise TypeError(msg)
557        if insitu is None: insitu = rcParams['insitu']
558        if not insitu:
559            scan = self.data.copy()
560        else:
561            scan = self.data
562        rows = xrange(scan.nrow())
563        from asap import asaplog
564        asaplog.push("Fitting:")
565        for r in rows:
566            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),scan.getbeam(r),scan.getif(r),scan.getpol(r), scan.getcycle(r))
567            asaplog.push(out, False)
568            self.x = scan._getabcissa(r)
569            self.y = scan._getspectrum(r)
570            self.data = None
571            self.fit()
572            x = self.get_parameters()
573            if plot:
574                self.plot(residual=True)
575                x = raw_input("Accept fit ([y]/n): ")
576                if x.upper() == 'N':
577                    continue
578            scan._setspectrum(self.fitter.getresidual(), r)
579        if plot:
580            self._p.unmap()
581            self._p = None
582        print_log()
583        return scan
584
Note: See TracBrowser for help on using the repository browser.