source: branches/alma/python/asapfitter.py @ 1612

Last change on this file since 1612 was 1612, checked in by Takeshi Nakazato, 15 years ago

New Development: No

JIRA Issue: Yes CAS-729, CAS-1147

Ready to Release: Yes

Interface Changes: No

What Interface Changed: Please list interface changes

Test Programs: List test programs

Put in Release Notes: Yes

Module(s): Module Names change impacts.

Description: Describe your changes here...

I have changed that almost all log messages are output to casapy.log,
not to the terminal window. After this change, asap becomes to depend on casapy
and is not running in standalone, because asap have to import taskinit module
to access casalogger.


  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.4 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log
4from asap import _n_bools
5from asap import mask_and
6from taskinit import *
7
8class fitter:
9    """
10    The fitting class for ASAP.
11    """
12
13    def __init__(self):
14        """
15        Create a fitter object. No state is set.
16        """
17        self.fitter = _asap.fitter()
18        self.x = None
19        self.y = None
20        self.mask = None
21        self.fitfunc = None
22        self.fitfuncs = None
23        self.fitted = False
24        self.data = None
25        self.components = 0
26        self._fittedrow = 0
27        self._p = None
28        self._selection = None
29        self.uselinear = False
30
31    def set_data(self, xdat, ydat, mask=None):
32        """
33        Set the absissa and ordinate for the fit. Also set the mask
34        indicationg valid points.
35        This can be used for data vectors retrieved from a scantable.
36        For scantable fitting use 'fitter.set_scan(scan, mask)'.
37        Parameters:
38            xdat:    the abcissa values
39            ydat:    the ordinate values
40            mask:    an optional mask
41
42        """
43        self.fitted = False
44        self.x = xdat
45        self.y = ydat
46        if mask == None:
47            self.mask = _n_bools(len(xdat), True)
48        else:
49            self.mask = mask
50        return
51
52    def set_scan(self, thescan=None, mask=None):
53        """
54        Set the 'data' (a scantable) of the fitter.
55        Parameters:
56            thescan:     a scantable
57            mask:        a msk retrieved from the scantable
58        """
59        if not thescan:
60            msg = "Please give a correct scan"
61            if rcParams['verbose']:
62                #print msg
63                casalog.post( msg, 'WARN' )
64                return
65            else:
66                raise TypeError(msg)
67        self.fitted = False
68        self.data = thescan
69        self.mask = None
70        if mask is None:
71            self.mask = _n_bools(self.data.nchan(), True)
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given with nonlinear least squares fit
81            lpoly:   use polynomial of the order given with linear least squares fit
82            gauss:   fit the number of gaussian specified
83        Example:
84            fitter.set_function(gauss=2) # will fit two gaussians
85            fitter.set_function(poly=3)  # will fit a 3rd order polynomial via nonlinear method
86            fitter.set_function(lpoly=3)  # will fit a 3rd order polynomial via linear method
87        """
88        #default poly order 0
89        n=0
90        if kwargs.has_key('poly'):
91            self.fitfunc = 'poly'
92            n = kwargs.get('poly')
93            self.components = [n]
94            self.uselinear = False
95        elif kwargs.has_key('lpoly'):
96            self.fitfunc = 'poly'
97            n = kwargs.get('lpoly')
98            self.components = [n]
99            self.uselinear = True
100        elif kwargs.has_key('gauss'):
101            n = kwargs.get('gauss')
102            self.fitfunc = 'gauss'
103            self.fitfuncs = [ 'gauss' for i in range(n) ]
104            self.components = [ 3 for i in range(n) ]
105            self.uselinear = False
106        else:
107            msg = "Invalid function type."
108            if rcParams['verbose']:
109                #print msg
110                casalog.post( msg, 'WARN' )
111                return
112            else:
113                raise TypeError(msg)
114
115        self.fitter.setexpression(self.fitfunc,n)
116        self.fitted = False
117        return
118
119    def fit(self, row=0, estimate=False):
120        """
121        Execute the actual fitting process. All the state has to be set.
122        Parameters:
123            row:        specify the row in the scantable
124            estimate:   auto-compute an initial parameter set (default False)
125                        This can be used to compute estimates even if fit was
126                        called before.
127        Example:
128            s = scantable('myscan.asap')
129            s.set_cursor(thepol=1)        # select second pol
130            f = fitter()
131            f.set_scan(s)
132            f.set_function(poly=0)
133            f.fit(row=0)                  # fit first row
134        """
135        if ((self.x is None or self.y is None) and self.data is None) \
136               or self.fitfunc is None:
137            msg = "Fitter not yet initialised. Please set data & fit function"
138            if rcParams['verbose']:
139                #print msg
140                casalog.post( msg, 'WARN' )
141                return
142            else:
143                raise RuntimeError(msg)
144
145        else:
146            if self.data is not None:
147                self.x = self.data._getabcissa(row)
148                self.y = self.data._getspectrum(row)
149                self.mask = mask_and(self.mask, self.data._getmask(row))
150                from asap import asaplog
151                asaplog.push("Fitting:")
152                i = row
153                out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),
154                                                                      self.data.getbeam(i),
155                                                                      self.data.getif(i),
156                                                                      self.data.getpol(i),
157                                                                      self.data.getcycle(i))
158                asaplog.push(out,False)
159        self.fitter.setdata(self.x, self.y, self.mask)
160        if self.fitfunc == 'gauss':
161            ps = self.fitter.getparameters()
162            if len(ps) == 0 or estimate:
163                self.fitter.estimate()
164        try:
165            fxdpar = list(self.fitter.getfixedparameters())
166            if len(fxdpar) and fxdpar.count(0) == 0:
167                 raise RuntimeError,"No point fitting, if all parameters are fixed."
168            if self.uselinear:
169                converged = self.fitter.lfit()
170            else:
171                converged = self.fitter.fit()
172            if not converged:
173                raise RuntimeError,"Fit didn't converge."
174        except RuntimeError, msg:
175            if rcParams['verbose']:
176                #print msg
177                casalog.post( msg, 'WARN' )
178            else:
179                raise
180        self._fittedrow = row
181        self.fitted = True
182        print_log()
183        return
184
185    def store_fit(self, filename=None):
186        """
187        Save the fit parameters.
188        Parameters:
189            filename:    if specified save as an ASCII file, if None (default)
190                         store it in the scnatable
191        """
192        if self.fitted and self.data is not None:
193            pars = list(self.fitter.getparameters())
194            fixed = list(self.fitter.getfixedparameters())
195            from asap.asapfit import asapfit
196            fit = asapfit()
197            fit.setparameters(pars)
198            fit.setfixedparameters(fixed)
199            fit.setfunctions(self.fitfuncs)
200            fit.setcomponents(self.components)
201            fit.setframeinfo(self.data._getcoordinfo())
202            if filename is not None:
203                import os
204                filename = os.path.expandvars(os.path.expanduser(filename))
205                if os.path.exists(filename):
206                    raise IOError("File '%s' exists." % filename)
207                fit.save(filename)
208            else:
209                self.data._addfit(fit,self._fittedrow)
210
211    #def set_parameters(self, params, fixed=None, component=None):
212    def set_parameters(self,*args,**kwargs):
213        """
214        Set the parameters to be fitted.
215        Parameters:
216              params:    a vector of parameters
217              fixed:     a vector of which parameters are to be held fixed
218                         (default is none)
219              component: in case of multiple gaussians, the index of the
220                         component
221        """
222        component = None
223        fixed = None
224        params = None
225
226        if len(args) and isinstance(args[0],dict):
227            kwargs = args[0]
228        if kwargs.has_key("fixed"): fixed = kwargs["fixed"]
229        if kwargs.has_key("params"): params = kwargs["params"]
230        if len(args) == 2 and isinstance(args[1], int):
231            component = args[1]
232        if self.fitfunc is None:
233            msg = "Please specify a fitting function first."
234            if rcParams['verbose']:
235                #print msg
236                casalog.post( msg, 'WARN' )
237                return
238            else:
239                raise RuntimeError(msg)
240        if self.fitfunc == "gauss" and component is not None:
241            if not self.fitted and sum(self.fitter.getparameters()) == 0:
242                pars = _n_bools(len(self.components)*3, False)
243                fxd = _n_bools(len(pars), False)
244            else:
245                pars = list(self.fitter.getparameters())
246                fxd = list(self.fitter.getfixedparameters())
247            i = 3*component
248            pars[i:i+3] = params
249            fxd[i:i+3] = fixed
250            params = pars
251            fixed = fxd
252        self.fitter.setparameters(params)
253        if fixed is not None:
254            self.fitter.setfixedparameters(fixed)
255        print_log()
256        return
257
258    def set_gauss_parameters(self, peak, centre, fwhm,
259                             peakfixed=0, centrefixed=0,
260                             fwhmfixed=0,
261                             component=0):
262        """
263        Set the Parameters of a 'Gaussian' component, set with set_function.
264        Parameters:
265            peak, centre, fwhm:  The gaussian parameters
266            peakfixed,
267            centrefixed,
268            fwhmfixed:           Optional parameters to indicate if
269                                 the paramters should be held fixed during
270                                 the fitting process. The default is to keep
271                                 all parameters flexible.
272            component:           The number of the component (Default is the
273                                 component 0)
274        """
275        if self.fitfunc != "gauss":
276            msg = "Function only operates on Gaussian components."
277            if rcParams['verbose']:
278                #print msg
279                casalog.post( msg, 'WARN' )
280                return
281            else:
282                raise ValueError(msg)
283        if 0 <= component < len(self.components):
284            d = {'params':[peak, centre, fwhm],
285                 'fixed':[peakfixed, centrefixed, fwhmfixed]}
286            self.set_parameters(d, component)
287        else:
288            msg = "Please select a valid  component."
289            if rcParams['verbose']:
290                #print msg
291                casalog.post( msg, 'WARN' )
292                return
293            else:
294                raise ValueError(msg)
295
296    def get_area(self, component=None):
297        """
298        Return the area under the fitted gaussian component.
299        Parameters:
300              component:   the gaussian component selection,
301                           default (None) is the sum of all components
302        Note:
303              This will only work for gaussian fits.
304        """
305        if not self.fitted: return
306        if self.fitfunc == "gauss":
307            pars = list(self.fitter.getparameters())
308            from math import log,pi,sqrt
309            fac = sqrt(pi/log(16.0))
310            areas = []
311            for i in range(len(self.components)):
312                j = i*3
313                cpars = pars[j:j+3]
314                areas.append(fac * cpars[0] * cpars[2])
315        else:
316            return None
317        if component is not None:
318            return areas[component]
319        else:
320            return sum(areas)
321
322    def get_errors(self, component=None):
323        """
324        Return the errors in the parameters.
325        Parameters:
326            component:    get the errors for the specified component
327                          only, default is all components
328        """
329        if not self.fitted:
330            msg = "Not yet fitted."
331            if rcParams['verbose']:
332                #print msg
333                casalog.post( msg, 'WARN' )
334                return
335            else:
336                raise RuntimeError(msg)
337        errs = list(self.fitter.geterrors())
338        cerrs = errs
339        if component is not None:
340            if self.fitfunc == "gauss":
341                i = 3*component
342                if i < len(errs):
343                    cerrs = errs[i:i+3]
344        return cerrs
345
346    def get_parameters(self, component=None, errors=False):
347        """
348        Return the fit paramters.
349        Parameters:
350             component:    get the parameters for the specified component
351                           only, default is all components
352        """
353        if not self.fitted:
354            msg = "Not yet fitted."
355            if rcParams['verbose']:
356                #print msg
357                casalog.post( msg, 'WARN' )
358                return
359            else:
360                raise RuntimeError(msg)
361        pars = list(self.fitter.getparameters())
362        fixed = list(self.fitter.getfixedparameters())
363        errs = list(self.fitter.geterrors())
364        area = []
365        if component is not None:
366            if self.fitfunc == "gauss":
367                i = 3*component
368                cpars = pars[i:i+3]
369                cfixed = fixed[i:i+3]
370                cerrs = errs[i:i+3]
371                a = self.get_area(component)
372                area = [a for i in range(3)]
373            else:
374                cpars = pars
375                cfixed = fixed
376                cerrs = errs
377        else:
378            cpars = pars
379            cfixed = fixed
380            cerrs = errs
381            if self.fitfunc == "gauss":
382                for c in range(len(self.components)):
383                  a = self.get_area(c)
384                  area += [a for i in range(3)]
385        fpars = self._format_pars(cpars, cfixed, errors and cerrs, area)
386        if rcParams['verbose']:
387            #print fpars
388            casalog.post( fpars )
389        return {'params':cpars, 'fixed':cfixed, 'formatted': fpars,
390                'errors':cerrs}
391
392    def _format_pars(self, pars, fixed, errors, area):
393        out = ''
394        if self.fitfunc == 'poly':
395            c = 0
396            for i in range(len(pars)):
397                fix = ""
398                if len(fixed) and fixed[i]: fix = "(fixed)"
399                if errors :
400                    out += '  p%d%s= %3.6f (%1.6f),' % (c,fix,pars[i], errors[i])
401                else:
402                    out += '  p%d%s= %3.6f,' % (c,fix,pars[i])
403                c+=1
404            out = out[:-1]  # remove trailing ','
405        elif self.fitfunc == 'gauss':
406            i = 0
407            c = 0
408            aunit = ''
409            ounit = ''
410            if self.data:
411                aunit = self.data.get_unit()
412                ounit = self.data.get_fluxunit()
413            while i < len(pars):
414                if len(area):
415                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n      area = %3.3f %s %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit, area[i],ounit,aunit)
416                else:
417                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit,ounit,aunit)
418                c+=1
419                i+=3
420        return out
421
422    def get_estimate(self):
423        """
424        Return the parameter estimates (for non-linear functions).
425        """
426        pars = self.fitter.getestimate()
427        fixed = self.fitter.getfixedparameters()
428        if rcParams['verbose']:
429            #print self._format_pars(pars,fixed,None)
430            casalog.post( self._format_pars(pars,fixed,None) )
431        return pars
432
433    def get_residual(self):
434        """
435        Return the residual of the fit.
436        """
437        if not self.fitted:
438            msg = "Not yet fitted."
439            if rcParams['verbose']:
440                #print msg
441                casalog.post( msg, 'WARN' )
442                return
443            else:
444                raise RuntimeError(msg)
445        return self.fitter.getresidual()
446
447    def get_chi2(self):
448        """
449        Return chi^2.
450        """
451        if not self.fitted:
452            msg = "Not yet fitted."
453            if rcParams['verbose']:
454                #print msg
455                casalog.post( msg, 'WARN' )
456                return
457            else:
458                raise RuntimeError(msg)
459        ch2 = self.fitter.getchi2()
460        if rcParams['verbose']:
461            #print 'Chi^2 = %3.3f' % (ch2)
462            casalog.post( 'Chi^2 = %3.3f' % (ch2) )
463        return ch2
464
465    def get_fit(self):
466        """
467        Return the fitted ordinate values.
468        """
469        if not self.fitted:
470            msg = "Not yet fitted."
471            if rcParams['verbose']:
472                #print msg
473                casalog.post( msg, 'WARN' )
474                return
475            else:
476                raise RuntimeError(msg)
477        return self.fitter.getfit()
478
479    def commit(self):
480        """
481        Return a new scan where the fits have been commited (subtracted)
482        """
483        if not self.fitted:
484            msg = "Not yet fitted."
485            if rcParams['verbose']:
486                #print msg
487                casalog.post( msg, 'WARN' )
488                return
489            else:
490                raise RuntimeError(msg)
491        from asap import scantable
492        if not isinstance(self.data, scantable):
493            msg = "Not a scantable"
494            if rcParams['verbose']:
495                #print msg
496                casalog.post( msg, 'WARN' )
497                return
498            else:
499                raise TypeError(msg)
500        scan = self.data.copy()
501        scan._setspectrum(self.fitter.getresidual())
502        print_log()
503        return scan
504
505    def plot(self, residual=False, components=None, plotparms=False, filename=None):
506        """
507        Plot the last fit.
508        Parameters:
509            residual:    an optional parameter indicating if the residual
510                         should be plotted (default 'False')
511            components:  a list of components to plot, e.g [0,1],
512                         -1 plots the total fit. Default is to only
513                         plot the total fit.
514            plotparms:   Inidicates if the parameter values should be present
515                         on the plot
516        """
517        if not self.fitted:
518            return
519        if not self._p or self._p.is_dead:
520            if rcParams['plotter.gui']:
521                from asap.asaplotgui import asaplotgui as asaplot
522            else:
523                from asap.asaplot import asaplot
524            self._p = asaplot()
525        self._p.hold()
526        self._p.clear()
527        self._p.set_panels()
528        self._p.palette(0)
529        tlab = 'Spectrum'
530        xlab = 'Abcissa'
531        ylab = 'Ordinate'
532        from matplotlib.numerix import ma,logical_not,logical_and,array
533        m = self.mask
534        if self.data:
535            tlab = self.data._getsourcename(self._fittedrow)
536            xlab = self.data._getabcissalabel(self._fittedrow)
537            m =  logical_and(self.mask,
538                             array(self.data._getmask(self._fittedrow),
539                                   copy=False))
540                             
541            ylab = self.data._get_ordinate_label()
542
543        colours = ["#777777","#dddddd","red","orange","purple","green","magenta", "cyan"]
544        nomask=True
545        for i in range(len(m)):
546            nomask = nomask and m[i]
547        label0='Masked Region'
548        label1='Spectrum'
549        if ( nomask ):
550            label0=label1
551        else:
552            y = ma.masked_array( self.y, mask = m )
553            self._p.palette(1,colours)
554            self._p.set_line( label = label1 )
555            self._p.plot( self.x, y )
556        self._p.palette(0,colours)
557        self._p.set_line(label=label0)
558        y = ma.masked_array(self.y,mask=logical_not(m))
559        self._p.plot(self.x, y)
560        if residual:
561            self._p.palette(7)
562            self._p.set_line(label='Residual')
563            y = ma.masked_array(self.get_residual(),
564                                  mask=logical_not(m))
565            self._p.plot(self.x, y)
566        self._p.palette(2)
567        if components is not None:
568            cs = components
569            if isinstance(components,int): cs = [components]
570            if plotparms:
571                self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
572            n = len(self.components)
573            self._p.palette(3)
574            for c in cs:
575                if 0 <= c < n:
576                    lab = self.fitfuncs[c]+str(c)
577                    self._p.set_line(label=lab)
578                    y = ma.masked_array(self.fitter.evaluate(c),
579                                          mask=logical_not(m))
580
581                    self._p.plot(self.x, y)
582                elif c == -1:
583                    self._p.palette(2)
584                    self._p.set_line(label="Total Fit")
585                    y = ma.masked_array(self.fitter.getfit(),
586                                          mask=logical_not(m))
587                    self._p.plot(self.x, y)
588        else:
589            self._p.palette(2)
590            self._p.set_line(label='Fit')
591            y = ma.masked_array(self.fitter.getfit(),
592                                  mask=logical_not(m))
593            self._p.plot(self.x, y)
594        xlim=[min(self.x),max(self.x)]
595        self._p.axes.set_xlim(xlim)
596        self._p.set_axes('xlabel',xlab)
597        self._p.set_axes('ylabel',ylab)
598        self._p.set_axes('title',tlab)
599        self._p.release()
600        if (not rcParams['plotter.gui']):
601            self._p.save(filename)
602        print_log()
603
604    def auto_fit(self, insitu=None, plot=False):
605        """
606        Return a scan where the function is applied to all rows for
607        all Beams/IFs/Pols.
608
609        """
610        from asap import scantable
611        if not isinstance(self.data, scantable) :
612            msg = "Data is not a scantable"
613            if rcParams['verbose']:
614                #print msg
615                casalog.post( msg, 'WARN' )
616                return
617            else:
618                raise TypeError(msg)
619        if insitu is None: insitu = rcParams['insitu']
620        if not insitu:
621            scan = self.data.copy()
622        else:
623            scan = self.data
624        rows = xrange(scan.nrow())
625        # Save parameters of baseline fits as a class attribute.
626        # NOTICE: This does not reflect changes in scantable!
627        if len(rows) > 0: self.blpars=[]
628        from asap import asaplog
629        asaplog.push("Fitting:")
630        for r in rows:
631            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),
632                                                                   scan.getbeam(r),
633                                                                   scan.getif(r),
634                                                                   scan.getpol(r),
635                                                                   scan.getcycle(r))
636            asaplog.push(out, False)
637            self.x = scan._getabcissa(r)
638            self.y = scan._getspectrum(r)
639            self.mask = mask_and(self.mask, scan._getmask(r))
640            self.data = None
641            self.fit()
642            x = self.get_parameters()
643            fpar = self.get_parameters()
644            if plot:
645                self.plot(residual=True)
646                x = raw_input("Accept fit ([y]/n): ")
647                if x.upper() == 'N':
648                    self.blpars.append(None)
649                    continue
650            scan._setspectrum(self.fitter.getresidual(), r)
651            self.blpars.append(fpar)
652        if plot:
653            self._p.unmap()
654            self._p = None
655        print_log()
656        return scan
657
Note: See TracBrowser for help on using the repository browser.