source: branches/asap4casa3.1.0/python/asapfitter.py @ 1935

Last change on this file since 1935 was 1935, checked in by Takeshi Nakazato, 14 years ago

New Development: No

JIRA Issue: No

Ready for Test: Yes

Interface Changes: No

What Interface Changed: Please list interface changes

Test Programs: List test programs

Put in Release Notes: Yes/No?

Module(s): Module Names change impacts.

Description: Describe your changes here...

Minor bug fix to work asapfitter.store_fit() on polynomial fitting.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.0 KB
Line 
1import _asap
2from asap.parameters import rcParams
3from asap.logging import asaplog, asaplog_post_dec
4from asap.utils import _n_bools, mask_and
5
6
7class fitter:
8    """
9    The fitting class for ASAP.
10    """
11    def __init__(self):
12        """
13        Create a fitter object. No state is set.
14        """
15        self.fitter = _asap.fitter()
16        self.x = None
17        self.y = None
18        self.mask = None
19        self.fitfunc = None
20        self.fitfuncs = None
21        self.fitted = False
22        self.data = None
23        self.components = 0
24        self._fittedrow = 0
25        self._p = None
26        self._selection = None
27        self.uselinear = False
28
29    def set_data(self, xdat, ydat, mask=None):
30        """
31        Set the absissa and ordinate for the fit. Also set the mask
32        indicationg valid points.
33        This can be used for data vectors retrieved from a scantable.
34        For scantable fitting use 'fitter.set_scan(scan, mask)'.
35        Parameters:
36            xdat:    the abcissa values
37            ydat:    the ordinate values
38            mask:    an optional mask
39
40        """
41        self.fitted = False
42        self.x = xdat
43        self.y = ydat
44        if mask == None:
45            self.mask = _n_bools(len(xdat), True)
46        else:
47            self.mask = mask
48        return
49
50    @asaplog_post_dec
51    def set_scan(self, thescan=None, mask=None):
52        """
53        Set the 'data' (a scantable) of the fitter.
54        Parameters:
55            thescan:     a scantable
56            mask:        a msk retrieved from the scantable
57        """
58        if not thescan:
59            msg = "Please give a correct scan"
60            raise TypeError(msg)
61        self.fitted = False
62        self.data = thescan
63        self.mask = None
64        if mask is None:
65            self.mask = _n_bools(self.data.nchan(), True)
66        else:
67            self.mask = mask
68        return
69
70    @asaplog_post_dec
71    def set_function(self, **kwargs):
72        """
73        Set the function to be fit.
74        Parameters:
75            poly:    use a polynomial of the order given with nonlinear least squares fit
76            lpoly:   use polynomial of the order given with linear least squares fit
77            gauss:   fit the number of gaussian specified
78            lorentz: fit the number of lorentzian specified
79        Example:
80            fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
81            fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial via linear method
82            fitter.set_function(gauss=2) # will fit two gaussians
83            fitter.set_function(lorentz=2) # will fit two lorentzians
84        """
85        #default poly order 0
86        n=0
87        if kwargs.has_key('poly'):
88            self.fitfunc = 'poly'
89            self.fitfuncs = ['poly']
90            n = kwargs.get('poly')
91            self.components = [n+1]
92            self.uselinear = False
93        elif kwargs.has_key('lpoly'):
94            self.fitfunc = 'poly'
95            self.fitfuncs = ['lpoly']
96            n = kwargs.get('lpoly')
97            self.components = [n+1]
98            self.uselinear = True
99        elif kwargs.has_key('gauss'):
100            n = kwargs.get('gauss')
101            self.fitfunc = 'gauss'
102            self.fitfuncs = [ 'gauss' for i in range(n) ]
103            self.components = [ 3 for i in range(n) ]
104            self.uselinear = False
105        elif kwargs.has_key('lorentz'):
106            n = kwargs.get('lorentz')
107            self.fitfunc = 'lorentz'
108            self.fitfuncs = [ 'lorentz' for i in range(n) ]
109            self.components = [ 3 for i in range(n) ]
110            self.uselinear = False
111        else:
112            msg = "Invalid function type."
113            raise TypeError(msg)
114
115        self.fitter.setexpression(self.fitfunc,n)
116        self.fitted = False
117        return
118
119    @asaplog_post_dec
120    def fit(self, row=0, estimate=False):
121        """
122        Execute the actual fitting process. All the state has to be set.
123        Parameters:
124            row:        specify the row in the scantable
125            estimate:   auto-compute an initial parameter set (default False)
126                        This can be used to compute estimates even if fit was
127                        called before.
128        Example:
129            s = scantable('myscan.asap')
130            s.set_cursor(thepol=1)        # select second pol
131            f = fitter()
132            f.set_scan(s)
133            f.set_function(poly=0)
134            f.fit(row=0)                  # fit first row
135        """
136        if ((self.x is None or self.y is None) and self.data is None) \
137               or self.fitfunc is None:
138            msg = "Fitter not yet initialised. Please set data & fit function"
139            raise RuntimeError(msg)
140
141        else:
142            if self.data is not None:
143                self.x = self.data._getabcissa(row)
144                self.y = self.data._getspectrum(row)
145                self.mask = mask_and(self.mask, self.data._getmask(row))
146                asaplog.push("Fitting:")
147                i = row
148                out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),
149                                                                      self.data.getbeam(i),
150                                                                      self.data.getif(i),
151                                                                      self.data.getpol(i),
152                                                                      self.data.getcycle(i))
153                asaplog.push(out,False)
154        self.fitter.setdata(self.x, self.y, self.mask)
155        if self.fitfunc == 'gauss' or self.fitfunc == 'lorentz':
156            ps = self.fitter.getparameters()
157            if len(ps) == 0 or estimate:
158                self.fitter.estimate()
159        fxdpar = list(self.fitter.getfixedparameters())
160        if len(fxdpar) and fxdpar.count(0) == 0:
161             raise RuntimeError,"No point fitting, if all parameters are fixed."
162        if self.uselinear:
163            converged = self.fitter.lfit()
164        else:
165            converged = self.fitter.fit()
166        if not converged:
167            raise RuntimeError,"Fit didn't converge."
168        self._fittedrow = row
169        self.fitted = True
170        return
171
172    def store_fit(self, filename=None):
173        """
174        Save the fit parameters.
175        Parameters:
176            filename:    if specified save as an ASCII file, if None (default)
177                         store it in the scnatable
178        """
179        if self.fitted and self.data is not None:
180            pars = list(self.fitter.getparameters())
181            fixed = list(self.fitter.getfixedparameters())
182            from asap.asapfit import asapfit
183            fit = asapfit()
184            fit.setparameters(pars)
185            fit.setfixedparameters(fixed)
186            fit.setfunctions(self.fitfuncs)
187            fit.setcomponents(self.components)
188            fit.setframeinfo(self.data._getcoordinfo())
189            if filename is not None:
190                import os
191                filename = os.path.expandvars(os.path.expanduser(filename))
192                if os.path.exists(filename):
193                    raise IOError("File '%s' exists." % filename)
194                fit.save(filename)
195            else:
196                self.data._addfit(fit,self._fittedrow)
197
198    @asaplog_post_dec
199    def set_parameters(self,*args,**kwargs):
200        """
201        Set the parameters to be fitted.
202        Parameters:
203              params:    a vector of parameters
204              fixed:     a vector of which parameters are to be held fixed
205                         (default is none)
206              component: in case of multiple gaussians/lorentzians,
207                         the index of the component
208        """
209        component = None
210        fixed = None
211        params = None
212
213        if len(args) and isinstance(args[0],dict):
214            kwargs = args[0]
215        if kwargs.has_key("fixed"): fixed = kwargs["fixed"]
216        if kwargs.has_key("params"): params = kwargs["params"]
217        if len(args) == 2 and isinstance(args[1], int):
218            component = args[1]
219        if self.fitfunc is None:
220            msg = "Please specify a fitting function first."
221            raise RuntimeError(msg)
222        if (self.fitfunc == "gauss" or self.fitfunc == 'lorentz') and component is not None:
223            if not self.fitted and sum(self.fitter.getparameters()) == 0:
224                pars = _n_bools(len(self.components)*3, False)
225                fxd = _n_bools(len(pars), False)
226            else:
227                pars = list(self.fitter.getparameters())
228                fxd = list(self.fitter.getfixedparameters())
229            i = 3*component
230            pars[i:i+3] = params
231            fxd[i:i+3] = fixed
232            params = pars
233            fixed = fxd
234        self.fitter.setparameters(params)
235        if fixed is not None:
236            self.fitter.setfixedparameters(fixed)
237        return
238
239    @asaplog_post_dec
240    def set_gauss_parameters(self, peak, centre, fwhm,
241                             peakfixed=0, centrefixed=0,
242                             fwhmfixed=0,
243                             component=0):
244        """
245        Set the Parameters of a 'Gaussian' component, set with set_function.
246        Parameters:
247            peak, centre, fwhm:  The gaussian parameters
248            peakfixed,
249            centrefixed,
250            fwhmfixed:           Optional parameters to indicate if
251                                 the paramters should be held fixed during
252                                 the fitting process. The default is to keep
253                                 all parameters flexible.
254            component:           The number of the component (Default is the
255                                 component 0)
256        """
257        if self.fitfunc != "gauss":
258            msg = "Function only operates on Gaussian components."
259            raise ValueError(msg)
260        if 0 <= component < len(self.components):
261            d = {'params':[peak, centre, fwhm],
262                 'fixed':[peakfixed, centrefixed, fwhmfixed]}
263            self.set_parameters(d, component)
264        else:
265            msg = "Please select a valid  component."
266            raise ValueError(msg)
267
268    @asaplog_post_dec
269    def set_lorentz_parameters(self, peak, centre, fwhm,
270                             peakfixed=0, centrefixed=0,
271                             fwhmfixed=0,
272                             component=0):
273        """
274        Set the Parameters of a 'Lorentzian' component, set with set_function.
275        Parameters:
276            peak, centre, fwhm:  The lorentzian parameters
277            peakfixed,
278            centrefixed,
279            fwhmfixed:           Optional parameters to indicate if
280                                 the paramters should be held fixed during
281                                 the fitting process. The default is to keep
282                                 all parameters flexible.
283            component:           The number of the component (Default is the
284                                 component 0)
285        """
286        if self.fitfunc != "lorentz":
287            msg = "Function only operates on Lorentzian components."
288            raise ValueError(msg)
289        if 0 <= component < len(self.components):
290            d = {'params':[peak, centre, fwhm],
291                 'fixed':[peakfixed, centrefixed, fwhmfixed]}
292            self.set_parameters(d, component)
293        else:
294            msg = "Please select a valid  component."
295            raise ValueError(msg)
296
297    def get_area(self, component=None):
298        """
299        Return the area under the fitted gaussian/lorentzian component.
300        Parameters:
301              component:   the gaussian/lorentzian component selection,
302                           default (None) is the sum of all components
303        Note:
304              This will only work for gaussian/lorentzian fits.
305        """
306        if not self.fitted: return
307        if self.fitfunc == "gauss" or self.fitfunc == "lorentz":
308            pars = list(self.fitter.getparameters())
309            from math import log,pi,sqrt
310            if self.fitfunc == "gauss":
311                fac = sqrt(pi/log(16.0))
312            elif self.fitfunc == "lorentz":
313                fac = pi/2.0
314            areas = []
315            for i in range(len(self.components)):
316                j = i*3
317                cpars = pars[j:j+3]
318                areas.append(fac * cpars[0] * cpars[2])
319        else:
320            return None
321        if component is not None:
322            return areas[component]
323        else:
324            return sum(areas)
325
326    @asaplog_post_dec
327    def get_errors(self, component=None):
328        """
329        Return the errors in the parameters.
330        Parameters:
331            component:    get the errors for the specified component
332                          only, default is all components
333        """
334        if not self.fitted:
335            msg = "Not yet fitted."
336            raise RuntimeError(msg)
337        errs = list(self.fitter.geterrors())
338        cerrs = errs
339        if component is not None:
340            if self.fitfunc == "gauss" or self.fitfunc == "lorentz":
341                i = 3*component
342                if i < len(errs):
343                    cerrs = errs[i:i+3]
344        return cerrs
345
346
347    @asaplog_post_dec
348    def get_parameters(self, component=None, errors=False):
349        """
350        Return the fit paramters.
351        Parameters:
352             component:    get the parameters for the specified component
353                           only, default is all components
354        """
355        if not self.fitted:
356            msg = "Not yet fitted."
357            raise RuntimeError(msg)
358        pars = list(self.fitter.getparameters())
359        fixed = list(self.fitter.getfixedparameters())
360        errs = list(self.fitter.geterrors())
361        area = []
362        if component is not None:
363            if self.fitfunc == "gauss" or self.fitfunc == "lorentz":
364                i = 3*component
365                cpars = pars[i:i+3]
366                cfixed = fixed[i:i+3]
367                cerrs = errs[i:i+3]
368                a = self.get_area(component)
369                area = [a for i in range(3)]
370            else:
371                cpars = pars
372                cfixed = fixed
373                cerrs = errs
374        else:
375            cpars = pars
376            cfixed = fixed
377            cerrs = errs
378            if self.fitfunc == "gauss" or self.fitfunc == "lorentz":
379                for c in range(len(self.components)):
380                  a = self.get_area(c)
381                  area += [a for i in range(3)]
382        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
383        asaplog.push(fpars)
384        return {'params':cpars, 'fixed':cfixed, 'formatted': fpars,
385                'errors':cerrs}
386
387    def _format_pars(self, pars, fixed, errors, area):
388        out = ''
389        if self.fitfunc == 'poly':
390            c = 0
391            for i in range(len(pars)):
392                fix = ""
393                if len(fixed) and fixed[i]: fix = "(fixed)"
394                if errors :
395                    out += '  p%d%s= %3.6f (%1.6f),' % (c,fix,pars[i], errors[i])
396                else:
397                    out += '  p%d%s= %3.6f,' % (c,fix,pars[i])
398                c+=1
399            out = out[:-1]  # remove trailing ','
400        elif self.fitfunc == 'gauss' or self.fitfunc == 'lorentz':
401            i = 0
402            c = 0
403            aunit = ''
404            ounit = ''
405            if self.data:
406                aunit = self.data.get_unit()
407                ounit = self.data.get_fluxunit()
408            while i < len(pars):
409                if len(area):
410                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n      area = %3.3f %s %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit, area[i],ounit,aunit)
411                else:
412                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit,ounit,aunit)
413                c+=1
414                i+=3
415        return out
416
417
418    @asaplog_post_dec
419    def get_estimate(self):
420        """
421        Return the parameter estimates (for non-linear functions).
422        """
423        pars = self.fitter.getestimate()
424        fixed = self.fitter.getfixedparameters()
425        asaplog.push(self._format_pars(pars,fixed,None,None))
426        return pars
427
428    @asaplog_post_dec
429    def get_residual(self):
430        """
431        Return the residual of the fit.
432        """
433        if not self.fitted:
434            msg = "Not yet fitted."
435            raise RuntimeError(msg)
436        return self.fitter.getresidual()
437
438    @asaplog_post_dec
439    def get_chi2(self):
440        """
441        Return chi^2.
442        """
443        if not self.fitted:
444            msg = "Not yet fitted."
445            raise RuntimeError(msg)
446        ch2 = self.fitter.getchi2()
447        asaplog.push( 'Chi^2 = %3.3f' % (ch2) )
448        return ch2
449
450    @asaplog_post_dec
451    def get_fit(self):
452        """
453        Return the fitted ordinate values.
454        """
455        if not self.fitted:
456            msg = "Not yet fitted."
457            raise RuntimeError(msg)
458        return self.fitter.getfit()
459
460    @asaplog_post_dec
461    def commit(self):
462        """
463        Return a new scan where the fits have been commited (subtracted)
464        """
465        if not self.fitted:
466            msg = "Not yet fitted."
467            raise RuntimeError(msg)
468        from asap import scantable
469        if not isinstance(self.data, scantable):
470            msg = "Not a scantable"
471            raise TypeError(msg)
472        scan = self.data.copy()
473        scan._setspectrum(self.fitter.getresidual())
474        return scan
475
476    @asaplog_post_dec
477    def plot(self, residual=False, components=None, plotparms=False,
478             filename=None):
479        """
480        Plot the last fit.
481        Parameters:
482            residual:    an optional parameter indicating if the residual
483                         should be plotted (default 'False')
484            components:  a list of components to plot, e.g [0,1],
485                         -1 plots the total fit. Default is to only
486                         plot the total fit.
487            plotparms:   Inidicates if the parameter values should be present
488                         on the plot
489        """
490        if not self.fitted:
491            return
492        if not self._p or self._p.is_dead:
493            if rcParams['plotter.gui']:
494                from asap.asaplotgui import asaplotgui as asaplot
495            else:
496                from asap.asaplot import asaplot
497            self._p = asaplot()
498        self._p.hold()
499        self._p.clear()
500        self._p.set_panels()
501        self._p.palette(0)
502        tlab = 'Spectrum'
503        xlab = 'Abcissa'
504        ylab = 'Ordinate'
505        from numpy import ma,logical_not,logical_and,array
506        m = self.mask
507        if self.data:
508            tlab = self.data._getsourcename(self._fittedrow)
509            xlab = self.data._getabcissalabel(self._fittedrow)
510            m =  logical_and(self.mask,
511                             array(self.data._getmask(self._fittedrow),
512                                   copy=False))
513
514            ylab = self.data._get_ordinate_label()
515
516        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
517        nomask=True
518        for i in range(len(m)):
519            nomask = nomask and m[i]
520        label0='Masked Region'
521        label1='Spectrum'
522        if ( nomask ):
523            label0=label1
524        else:
525            y = ma.masked_array( self.y, mask = m )
526            self._p.palette(1,colours)
527            self._p.set_line( label = label1 )
528            self._p.plot( self.x, y )
529        self._p.palette(0,colours)
530        self._p.set_line(label=label0)
531        y = ma.masked_array(self.y,mask=logical_not(m))
532        self._p.plot(self.x, y)
533        if residual:
534            self._p.palette(7)
535            self._p.set_line(label='Residual')
536            y = ma.masked_array(self.get_residual(),
537                                  mask=logical_not(m))
538            self._p.plot(self.x, y)
539        self._p.palette(2)
540        if components is not None:
541            cs = components
542            if isinstance(components,int): cs = [components]
543            if plotparms:
544                self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
545            n = len(self.components)
546            self._p.palette(3)
547            for c in cs:
548                if 0 <= c < n:
549                    lab = self.fitfuncs[c]+str(c)
550                    self._p.set_line(label=lab)
551                    y = ma.masked_array(self.fitter.evaluate(c),
552                                          mask=logical_not(m))
553
554                    self._p.plot(self.x, y)
555                elif c == -1:
556                    self._p.palette(2)
557                    self._p.set_line(label="Total Fit")
558                    y = ma.masked_array(self.fitter.getfit(),
559                                          mask=logical_not(m))
560                    self._p.plot(self.x, y)
561        else:
562            self._p.palette(2)
563            self._p.set_line(label='Fit')
564            y = ma.masked_array(self.fitter.getfit(),
565                                  mask=logical_not(m))
566            self._p.plot(self.x, y)
567        xlim=[min(self.x),max(self.x)]
568        self._p.axes.set_xlim(xlim)
569        self._p.set_axes('xlabel',xlab)
570        self._p.set_axes('ylabel',ylab)
571        self._p.set_axes('title',tlab)
572        self._p.release()
573        if (not rcParams['plotter.gui']):
574            self._p.save(filename)
575
576    @asaplog_post_dec
577    def auto_fit(self, insitu=None, plot=False):
578        """
579        Return a scan where the function is applied to all rows for
580        all Beams/IFs/Pols.
581
582        """
583        from asap import scantable
584        if not isinstance(self.data, scantable) :
585            msg = "Data is not a scantable"
586            raise TypeError(msg)
587        if insitu is None: insitu = rcParams['insitu']
588        if not insitu:
589            scan = self.data.copy()
590        else:
591            scan = self.data
592        rows = xrange(scan.nrow())
593        # Save parameters of baseline fits as a class attribute.
594        # NOTICE: This does not reflect changes in scantable!
595        if len(rows) > 0: self.blpars=[]
596        asaplog.push("Fitting:")
597        for r in rows:
598            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),
599                                                                   scan.getbeam(r),
600                                                                   scan.getif(r),
601                                                                   scan.getpol(r),
602                                                                   scan.getcycle(r))
603            asaplog.push(out, False)
604            self.x = scan._getabcissa(r)
605            self.y = scan._getspectrum(r)
606            self.mask = mask_and(self.mask, scan._getmask(r))
607            self.data = None
608            self.fit()
609            x = self.get_parameters()
610            fpar = self.get_parameters()
611            if plot:
612                self.plot(residual=True)
613                x = raw_input("Accept fit ([y]/n): ")
614                if x.upper() == 'N':
615                    self.blpars.append(None)
616                    continue
617            scan._setspectrum(self.fitter.getresidual(), r)
618            self.blpars.append(fpar)
619        if plot:
620            self._p.unmap()
621            self._p = None
622        return scan
Note: See TracBrowser for help on using the repository browser.