source: branches/Release2.0/python/asapfitter.py @ 2401

Last change on this file since 2401 was 1039, checked in by mar637, 18 years ago

minor fix to printing of areas by component. The sum of all areas was printed instead of the area of the individual components.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 17.7 KB
Line 
1import _asap
2from asap import rcParams
3from asap import print_log
4
5class fitter:
6    """
7    The fitting class for ASAP.
8    """
9
10    def __init__(self):
11        """
12        Create a fitter object. No state is set.
13        """
14        self.fitter = _asap.fitter()
15        self.x = None
16        self.y = None
17        self.mask = None
18        self.fitfunc = None
19        self.fitfuncs = None
20        self.fitted = False
21        self.data = None
22        self.components = 0
23        self._fittedrow = 0
24        self._p = None
25        self._selection = None
26
27    def set_data(self, xdat, ydat, mask=None):
28        """
29        Set the absissa and ordinate for the fit. Also set the mask
30        indicationg valid points.
31        This can be used for data vectors retrieved from a scantable.
32        For scantable fitting use 'fitter.set_scan(scan, mask)'.
33        Parameters:
34            xdat:    the abcissa values
35            ydat:    the ordinate values
36            mask:    an optional mask
37
38        """
39        self.fitted = False
40        self.x = xdat
41        self.y = ydat
42        if mask == None:
43            from numarray import ones
44            self.mask = ones(len(xdat))
45        else:
46            self.mask = mask
47        return
48
49    def set_scan(self, thescan=None, mask=None):
50        """
51        Set the 'data' (a scantable) of the fitter.
52        Parameters:
53            thescan:     a scantable
54            mask:        a msk retireved from the scantable
55        """
56        if not thescan:
57            msg = "Please give a correct scan"
58            if rcParams['verbose']:
59                print msg
60                return
61            else:
62                raise TypeError(msg)
63        self.fitted = False
64        self.data = thescan
65        if mask is None:
66            from numarray import ones
67            self.mask = ones(self.data.nchan())
68        else:
69            self.mask = mask
70        return
71
72    def set_function(self, **kwargs):
73        """
74        Set the function to be fit.
75        Parameters:
76            poly:    use a polynomial of the order given
77            gauss:   fit the number of gaussian specified
78        Example:
79            fitter.set_function(gauss=2) # will fit two gaussians
80            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
81        """
82        #default poly order 0
83        n=0
84        if kwargs.has_key('poly'):
85            self.fitfunc = 'poly'
86            n = kwargs.get('poly')
87            self.components = [n]
88        elif kwargs.has_key('gauss'):
89            n = kwargs.get('gauss')
90            self.fitfunc = 'gauss'
91            self.fitfuncs = [ 'gauss' for i in range(n) ]
92            self.components = [ 3 for i in range(n) ]
93        else:
94            msg = "Invalid function type."
95            if rcParams['verbose']:
96                print msg
97                return
98            else:
99                raise TypeError(msg)
100
101        self.fitter.setexpression(self.fitfunc,n)
102        return
103
104    def fit(self, row=0):
105        """
106        Execute the actual fitting process. All the state has to be set.
107        Parameters:
108            row:    specify the row in the scantable
109        Example:
110            s = scantable('myscan.asap')
111            s.set_cursor(thepol=1)        # select second pol
112            f = fitter()
113            f.set_scan(s)
114            f.set_function(poly=0)
115            f.fit(row=0)                  # fit first row
116        """
117        if ((self.x is None or self.y is None) and self.data is None) \
118               or self.fitfunc is None:
119            msg = "Fitter not yet initialised. Please set data & fit function"
120            if rcParams['verbose']:
121                print msg
122                return
123            else:
124                raise RuntimeError(msg)
125
126        else:
127            if self.data is not None:
128                self.x = self.data._getabcissa(row)
129                self.y = self.data._getspectrum(row)
130                from asap import asaplog
131                asaplog.push("Fitting:")
132                i = row
133                out = "Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (self.data.getscan(i),self.data.getbeam(i),self.data.getif(i),self.data.getpol(i), self.data.getcycle(i))
134                asaplog.push(out)
135        self.fitter.setdata(self.x, self.y, self.mask)
136        if self.fitfunc == 'gauss':
137            ps = self.fitter.getparameters()
138            if len(ps) == 0:
139                self.fitter.estimate()
140        try:
141            self.fitter.fit()
142        except RuntimeError, msg:
143            if rcParams['verbose']:
144                print msg
145            else:
146                raise
147        self._fittedrow = row
148        self.fitted = True
149        print_log()
150        return
151
152    def store_fit(self):
153        """
154        Store the fit parameters in the scantable.
155        """
156        if self.fitted and self.data is not None:
157            pars = list(self.fitter.getparameters())
158            fixed = list(self.fitter.getfixedparameters())
159            from asap.asapfit import asapfit
160            fit = asapfit()
161            fit.setparameters(pars)
162            fit.setfixedparameters(fixed)
163            fit.setfunctions(self.fitfuncs)
164            fit.setcomponents(self.components)
165            fit.setframeinfo(self.data._getcoordinfo())
166            self.data._addfit(fit,self._fittedrow)
167
168    #def set_parameters(self, params, fixed=None, component=None):
169    def set_parameters(self,*args,**kwargs):
170        """
171        Set the parameters to be fitted.
172        Parameters:
173              params:    a vector of parameters
174              fixed:     a vector of which parameters are to be held fixed
175                         (default is none)
176              component: in case of multiple gaussians, the index of the
177                         component
178        """
179        component = None
180        fixed = None
181        params = None
182
183        if len(args) and isinstance(args[0],dict):
184            kwargs = args[0]
185        if kwargs.has_key("fixed"): fixed = kwargs["fixed"]
186        if kwargs.has_key("params"): params = kwargs["params"]
187        if len(args) == 2 and isinstance(args[1], int):
188            component = args[1]
189        if self.fitfunc is None:
190            msg = "Please specify a fitting function first."
191            if rcParams['verbose']:
192                print msg
193                return
194            else:
195                raise RuntimeError(msg)
196        if self.fitfunc == "gauss" and component is not None:
197            if not self.fitted and sum(self.fitter.getparameters()) == 0:
198                from numarray import zeros
199                pars = list(zeros(len(self.components)*3))
200                fxd = list(zeros(len(pars)))
201            else:
202                pars = list(self.fitter.getparameters())
203                fxd = list(self.fitter.getfixedparameters())
204            i = 3*component
205            pars[i:i+3] = params
206            fxd[i:i+3] = fixed
207            params = pars
208            fixed = fxd
209        self.fitter.setparameters(params)
210        if fixed is not None:
211            self.fitter.setfixedparameters(fixed)
212        print_log()
213        return
214
215    def set_gauss_parameters(self, peak, centre, fhwm,
216                             peakfixed=0, centerfixed=0,
217                             fhwmfixed=0,
218                             component=0):
219        """
220        Set the Parameters of a 'Gaussian' component, set with set_function.
221        Parameters:
222            peak, centre, fhwm:  The gaussian parameters
223            peakfixed,
224            centerfixed,
225            fhwmfixed:           Optional parameters to indicate if
226                                 the paramters should be held fixed during
227                                 the fitting process. The default is to keep
228                                 all parameters flexible.
229            component:           The number of the component (Default is the
230                                 component 0)
231        """
232        if self.fitfunc != "gauss":
233            msg = "Function only operates on Gaussian components."
234            if rcParams['verbose']:
235                print msg
236                return
237            else:
238                raise ValueError(msg)
239        if 0 <= component < len(self.components):
240            d = {'params':[peak, centre, fhwm],
241                 'fixed':[peakfixed, centerfixed, fhwmfixed]}
242            self.set_parameters(d, component)
243        else:
244            msg = "Please select a valid  component."
245            if rcParams['verbose']:
246                print msg
247                return
248            else:
249                raise ValueError(msg)
250
251    def get_area(self, component=None):
252        """
253        Return the area under the fitted gaussian component.
254        Parameters:
255              component:   the gaussian component selection,
256                           default (None) is the sum of all components
257        Note:
258              This will only work for gaussian fits.
259        """
260        if not self.fitted: return
261        if self.fitfunc == "gauss":
262            pars = list(self.fitter.getparameters())
263            from math import log,pi,sqrt
264            fac = sqrt(pi/log(16.0))
265            areas = []
266            for i in range(len(self.components)):
267                j = i*3
268                cpars = pars[j:j+3]
269                areas.append(fac * cpars[0] * cpars[2])
270        else:
271            return None
272        if component is not None:
273            return areas[component]
274        else:
275            return sum(areas)
276
277    def get_parameters(self, component=None):
278        """
279        Return the fit paramters.
280        Parameters:
281             component:    get the parameters for the specified component
282                           only, default is all components
283        """
284        if not self.fitted:
285            msg = "Not yet fitted."
286            if rcParams['verbose']:
287                print msg
288                return
289            else:
290                raise RuntimeError(msg)
291        pars = list(self.fitter.getparameters())
292        fixed = list(self.fitter.getfixedparameters())
293        area = []
294        if component is not None:
295            if self.fitfunc == "gauss":
296                i = 3*component
297                cpars = pars[i:i+3]
298                cfixed = fixed[i:i+3]
299                a = self.get_area(component)
300                area = [a for i in range(3)]
301            else:
302                cpars = pars
303                cfixed = fixed
304        else:
305            cpars = pars
306            cfixed = fixed
307            if self.fitfunc == "gauss":
308                for c in range(len(self.components)):
309                  a = self.get_area(c)
310                  area += [a for i in range(3)]
311
312        fpars = self._format_pars(cpars, cfixed, area)
313        if rcParams['verbose']:
314            print fpars
315        return {'params':cpars, 'fixed':cfixed, 'formatted': fpars}
316
317    def _format_pars(self, pars, fixed, area):
318        out = ''
319        if self.fitfunc == 'poly':
320            c = 0
321            for i in range(len(pars)):
322                fix = ""
323                if fixed[i]: fix = "(fixed)"
324                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
325                c+=1
326            out = out[:-1]  # remove trailing ','
327        elif self.fitfunc == 'gauss':
328            i = 0
329            c = 0
330            aunit = ''
331            ounit = ''
332            if self.data:
333                aunit = self.data.get_unit()
334                ounit = self.data.get_fluxunit()
335            while i < len(pars):
336                if len(area):
337                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n      area = %3.3f %s %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit, area[i],ounit,aunit)
338                else:
339                    out += '  %2d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s\n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit,ounit,aunit)
340                c+=1
341                i+=3
342        return out
343
344    def get_estimate(self):
345        """
346        Return the parameter estimates (for non-linear functions).
347        """
348        pars = self.fitter.getestimate()
349        fixed = self.fitter.getfixedparameters()
350        if rcParams['verbose']:
351            print self._format_pars(pars,fixed,None)
352        return pars
353
354    def get_residual(self):
355        """
356        Return the residual of the fit.
357        """
358        if not self.fitted:
359            msg = "Not yet fitted."
360            if rcParams['verbose']:
361                print msg
362                return
363            else:
364                raise RuntimeError(msg)
365        return self.fitter.getresidual()
366
367    def get_chi2(self):
368        """
369        Return chi^2.
370        """
371        if not self.fitted:
372            msg = "Not yet fitted."
373            if rcParams['verbose']:
374                print msg
375                return
376            else:
377                raise RuntimeError(msg)
378        ch2 = self.fitter.getchi2()
379        if rcParams['verbose']:
380            print 'Chi^2 = %3.3f' % (ch2)
381        return ch2
382
383    def get_fit(self):
384        """
385        Return the fitted ordinate values.
386        """
387        if not self.fitted:
388            msg = "Not yet fitted."
389            if rcParams['verbose']:
390                print msg
391                return
392            else:
393                raise RuntimeError(msg)
394        return self.fitter.getfit()
395
396    def commit(self):
397        """
398        Return a new scan where the fits have been commited (subtracted)
399        """
400        if not self.fitted:
401            print "Not yet fitted."
402            msg = "Not yet fitted."
403            if rcParams['verbose']:
404                print msg
405                return
406            else:
407                raise RuntimeError(msg)
408        from asap import scantable
409        if not isinstance(self.data, scantable):
410            msg = "Not a scantable"
411            if rcParams['verbose']:
412                print msg
413                return
414            else:
415                raise TypeError(msg)
416        scan = self.data.copy()
417        scan._setspectrum(self.fitter.getresidual())
418        print_log()
419
420    def plot(self, residual=False, components=None, plotparms=False, filename=None):
421        """
422        Plot the last fit.
423        Parameters:
424            residual:    an optional parameter indicating if the residual
425                         should be plotted (default 'False')
426            components:  a list of components to plot, e.g [0,1],
427                         -1 plots the total fit. Default is to only
428                         plot the total fit.
429            plotparms:   Inidicates if the parameter values should be present
430                         on the plot
431        """
432        if not self.fitted:
433            return
434        if not self._p or self._p.is_dead:
435            if rcParams['plotter.gui']:
436                from asap.asaplotgui import asaplotgui as asaplot
437            else:
438                from asap.asaplot import asaplot
439            self._p = asaplot()
440        self._p.hold()
441        self._p.clear()
442        self._p.set_panels()
443        self._p.palette(0)
444        tlab = 'Spectrum'
445        xlab = 'Abcissa'
446        ylab = 'Ordinate'
447        m = None
448        if self.data:
449            tlab = self.data._getsourcename(self._fittedrow)
450            xlab = self.data._getabcissalabel(self._fittedrow)
451            m = self.data._getmask(self._fittedrow)
452            ylab = self.data._get_ordinate_label()
453
454        colours = ["#777777","#bbbbbb","red","orange","purple","green","magenta", "cyan"]
455        self._p.palette(0,colours)
456        self._p.set_line(label='Spectrum')
457        self._p.plot(self.x, self.y, m)
458        if residual:
459            self._p.palette(1)
460            self._p.set_line(label='Residual')
461            self._p.plot(self.x, self.get_residual(), m)
462        self._p.palette(2)
463        if components is not None:
464            cs = components
465            if isinstance(components,int): cs = [components]
466            if plotparms:
467                self._p.text(0.15,0.15,str(self.get_parameters()['formatted']),size=8)
468            n = len(self.components)
469            self._p.palette(3)
470            for c in cs:
471                if 0 <= c < n:
472                    lab = self.fitfuncs[c]+str(c)
473                    self._p.set_line(label=lab)
474                    self._p.plot(self.x, self.fitter.evaluate(c), m)
475                elif c == -1:
476                    self._p.palette(2)
477                    self._p.set_line(label="Total Fit")
478                    self._p.plot(self.x, self.get_fit(), m)
479        else:
480            self._p.palette(2)
481            self._p.set_line(label='Fit')
482            self._p.plot(self.x, self.get_fit(), m)
483        xlim=[min(self.x),max(self.x)]
484        self._p.axes.set_xlim(xlim)
485        self._p.set_axes('xlabel',xlab)
486        self._p.set_axes('ylabel',ylab)
487        self._p.set_axes('title',tlab)
488        self._p.release()
489        if (not rcParams['plotter.gui']):
490            self._p.save(filename)
491        print_log()
492
493    def auto_fit(self, insitu=None):
494        """
495        Return a scan where the function is applied to all rows for
496        all Beams/IFs/Pols.
497
498        """
499        from asap import scantable
500        if not isinstance(self.data, scantable) :
501            msg = "Data is not a scantable"
502            if rcParams['verbose']:
503                print msg
504                return
505            else:
506                raise TypeError(msg)
507        if insitu is None: insitu = rcParams['insitu']
508        if not insitu:
509            scan = self.data.copy()
510        else:
511            scan = self.data
512        rows = xrange(scan.nrow())
513        from asap import asaplog
514        asaplog.push("Fitting:")
515        for r in rows:
516            out = " Scan[%d] Beam[%d] IF[%d] Pol[%d] Cycle[%d]" % (scan.getscan(r),scan.getbeam(r),scan.getif(r),scan.getpol(r), scan.getcycle(r))
517            asaplog.push(out, False)
518            self.x = scan._getabcissa(r)
519            self.y = scan._getspectrum(r)
520            self.data = None
521            self.fit()
522            x = self.get_parameters()
523            scan._setspectrum(self.fitter.getresidual(), r)
524        print_log()
525        return scan
526
Note: See TracBrowser for help on using the repository browser.