source: trunk/python/asapfitter.py @ 526

Last change on this file since 526 was 526, checked in by mar637, 19 years ago

added more help

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.9 KB
RevLine 
[113]1import _asap
[259]2from asap import rcParams
[113]3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
[515]27        self.fitfuncs = None
[113]28        self.fitted = False
29        self.data = None
[515]30        self.components = 0
31        self._fittedrow = 0
[113]32        self._p = None
33        self._vb = True
[515]34        self._selection = None
[113]35
36    def set_data(self, xdat, ydat, mask=None):
37        """
[158]38        Set the absissa and ordinate for the fit. Also set the mask
[113]39        indicationg valid points.
40        This can be used for data vectors retrieved from a scantable.
41        For scantable fitting use 'fitter.set_scan(scan, mask)'.
42        Parameters:
[158]43            xdat:    the abcissa values
[113]44            ydat:    the ordinate values
45            mask:    an optional mask
46       
47        """
48        self.fitted = False
49        self.x = xdat
50        self.y = ydat
51        if mask == None:
52            from numarray import ones
53            self.mask = ones(len(xdat))
54        else:
55            self.mask = mask
56        return
57
58    def set_scan(self, thescan=None, mask=None):
59        """
60        Set the 'data' (a scantable) of the fitter.
61        Parameters:
62            thescan:     a scantable
63            mask:        a msk retireved from the scantable
64        """
65        if not thescan:
66            print "Please give a correct scan"
67        self.fitted = False
68        self.data = thescan
69        if mask is None:
70            from numarray import ones
71            self.mask = ones(self.data.nchan())
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given
81            gauss:   fit the number of gaussian specified
82        Example:
83            fitter.set_function(gauss=2) # will fit two gaussians
84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
85        """
[515]86        #default poly order 0       
87        n=0
[113]88        if kwargs.has_key('poly'):
89            self.fitfunc = 'poly'
90            n = kwargs.get('poly')
[515]91            self.components = [n]
[113]92        elif kwargs.has_key('gauss'):
93            n = kwargs.get('gauss')
94            self.fitfunc = 'gauss'
[515]95            self.fitfuncs = [ 'gauss' for i in range(n) ]
96            self.components = [ 3 for i in range(n) ]
97        else:
98            print "Invalid function type."
99            return
[113]100        self.fitter.setexpression(self.fitfunc,n)
101        return
102           
[515]103    def fit(self, row=0):
[113]104        """
105        Execute the actual fitting process. All the state has to be set.
106        Parameters:
[526]107            row:    specify the row in the scantable
[113]108        Example:
[515]109            s = scantable('myscan.asap')
110            s.set_cursor(thepol=1)        # select second pol
[113]111            f = fitter()
112            f.set_scan(s)
113            f.set_function(poly=0)
[515]114            f.fit(row=0)                  # fit first row
[113]115        """
116        if ((self.x is None or self.y is None) and self.data is None) \
117               or self.fitfunc is None:
118            print "Fitter not yet initialised. Please set data & fit function"
119            return
120        else:
121            if self.data is not None:
[515]122                self.x = self.data._getabcissa(row)
123                self.y = self.data._getspectrum(row)
[113]124                print "Fitting:"
[259]125                vb = self.data._vb
126                self.data._vb = True
[515]127                self.selection = self.data.get_cursor()
[259]128                self.data._vb = vb
[515]129        self.fitter.setdata(self.x, self.y, self.mask)
[113]130        if self.fitfunc == 'gauss':
131            ps = self.fitter.getparameters()
132            if len(ps) == 0:
133                self.fitter.estimate()
134        self.fitter.fit()
[515]135        self._fittedrow = row
[113]136        self.fitted = True
137        return
138
[515]139    def store_fit(self):
[526]140        """
141        Store the fit parameters in the scantable.
142        """
[515]143        if self.fitted and self.data is not None:
144            pars = list(self.fitter.getparameters())
145            fixed = list(self.fitter.getfixedparameters())
146            self.data._addfit(self._fittedrow, pars, fixed,
147                              self.fitfuncs, self.components)
148
149    def set_parameters(self, params, fixed=None, component=None):
[526]150        """
151        Set the parameters to be fitted.
152        Parameters:
153              params:    a vector of parameters
154              fixed:     a vector of which parameters are to be held fixed
155                         (default is none)
156              component: in case of multiple gaussians, the index of the
157                         component
158             """
[515]159        if self.fitfunc is None:
160            print "Please specify a fitting function first."
161            return
162        if self.fitfunc == "gauss" and component is not None:
163            if not self.fitted:
164                from numarray import zeros
165                pars = list(zeros(len(self.components)*3))
166                fxd = list(zeros(len(pars)))
167            else:
168                pars = list(self.fitter.getparameters())             
169                fxd = list(self.fitter.getfixedparameters())
170            i = 3*component
171            pars[i:i+3] = params
172            fxd[i:i+3] = fixed
173            params = pars
174            fixed = fxd         
[113]175        self.fitter.setparameters(params)
176        if fixed is not None:
177            self.fitter.setfixedparameters(fixed)
178        return
[515]179
180    def set_gauss_parameters(self, peak, centre, fhwm,
181                             peakfixed=False, centerfixed=False,
182                             fhwmfixed=False,
183                             component=0):
[113]184        """
[515]185        Set the Parameters of a 'Gaussian' component, set with set_function.
186        Parameters:
187            peak, centre, fhwm:  The gaussian parameters
188            peakfixed,
189            centerfixed,
190            fhwmfixed:           Optional parameters to indicate if
191                                 the paramters should be held fixed during
192                                 the fitting process. The default is to keep
193                                 all parameters flexible.
[526]194            component:           The number of the component (Default is the
195                                 component 0)
[515]196        """
197        if self.fitfunc != "gauss":
198            print "Function only operates on Gaussian components."
199            return
200        if 0 <= component < len(self.components):
201            self.set_parameters([peak, centre, fhwm],
202                                [peakfixed, centerfixed, fhwmfixed],
203                                component)
204        else:
205            print "Please select a valid  component."
206            return
207       
208    def get_parameters(self, component=None):
209        """
[113]210        Return the fit paramters.
[526]211        Parameters:
212             component:    get the parameters for the specified component
213                           only, default is all components
[113]214        """
215        if not self.fitted:
216            print "Not yet fitted."
217        pars = list(self.fitter.getparameters())
218        fixed = list(self.fitter.getfixedparameters())
[515]219        if component is not None:           
220            if self.fitfunc == "gauss":
221                i = 3*component
222                cpars = pars[i:i+3]
223                cfixed = fixed[i:i+3]
224            else:
225                cpars = pars
226                cfixed = fixed               
227        else:
228            cpars = pars
229            cfixed = fixed
230        fpars = self._format_pars(cpars, cfixed)
[113]231        if self._vb:
[515]232            print fpars
233        return cpars, cfixed, fpars
[113]234   
[515]235    def _format_pars(self, pars, fixed):
[113]236        out = ''
237        if self.fitfunc == 'poly':
238            c = 0
[515]239            for i in range(len(pars)):
240                fix = ""
241                if fixed[i]: fix = "(fixed)"
242                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
[113]243                c+=1
[515]244            out = out[:-1]  # remove trailing ','
[113]245        elif self.fitfunc == 'gauss':
246            i = 0
247            c = 0
[515]248            aunit = ''
249            ounit = ''
[113]250            if self.data:
[515]251                aunit = self.data.get_unit()
252                ounit = self.data.get_fluxunit()
[113]253            while i < len(pars):
[515]254                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
[113]255                c+=1
256                i+=3
257        return out
258       
259    def get_estimate(self):
260        """
[515]261        Return the parameter estimates (for non-linear functions).
[113]262        """
263        pars = self.fitter.getestimate()
264        if self._vb:
265            print self._format_pars(pars)
266        return pars
267       
268
269    def get_residual(self):
270        """
271        Return the residual of the fit.
272        """
273        if not self.fitted:
274            print "Not yet fitted."
275        return self.fitter.getresidual()
276
277    def get_chi2(self):
278        """
279        Return chi^2.
280        """
281        if not self.fitted:
282            print "Not yet fitted."
283        ch2 = self.fitter.getchi2()
284        if self._vb:
285            print 'Chi^2 = %3.3f' % (ch2)
286        return ch2
287
288    def get_fit(self):
289        """
290        Return the fitted ordinate values.
291        """
292        if not self.fitted:
293            print "Not yet fitted."
294        return self.fitter.getfit()
295
296    def commit(self):
297        """
[526]298        Return a new scan where the fits have been commited (subtracted)
[113]299        """
300        if not self.fitted:
301            print "Not yet fitted."
302        if self.data is not scantable:
303            print "Only works with scantables"
304            return
305        scan = self.data.copy()
[259]306        scan._setspectrum(self.fitter.getresidual())
[113]307
[526]308    def plot(self, residual=False, components=None, plotparms=False):
[113]309        """
310        Plot the last fit.
311        Parameters:
312            residual:    an optional parameter indicating if the residual
313                         should be plotted (default 'False')
[526]314            components:  a list of components to plot, e.g [0,1],
315                         -1 plots the total fit. Default is to only
316                         plot the total fit.
317            plotparms:   Inidicates if the parameter values should be present
318                         on the plot
[113]319        """
320        if not self.fitted:
321            return
322        if not self._p:
323            from asap.asaplot import ASAPlot
324            self._p = ASAPlot()
[298]325        if self._p.is_dead:
[190]326            from asap.asaplot import ASAPlot
327            self._p = ASAPlot()
[113]328        self._p.clear()
[515]329        self._p.set_panels()
330        self._p.palette(1)
[113]331        tlab = 'Spectrum'
[515]332        xlab = 'Abcissa'       
333        m = ()
[113]334        if self.data:
[515]335            tlab = self.data._getsourcename(self._fittedrow)
336            xlab = self.data._getabcissalabel(self._fittedrow)
337            m = self.data._getmask(self._fittedrow)
338            ylab = r'Flux'
339
340        colours = ["grey60","grey80","red","orange","purple","yellow","magenta", "cyan"]
341        self._p.palette(1,colours)
342        self._p.set_line(label='Spectrum')
[113]343        self._p.plot(self.x, self.y, m)
344        if residual:
[515]345            self._p.palette(2)
346            self._p.set_line(label='Residual')
[113]347            self._p.plot(self.x, self.get_residual(), m)
[515]348        self._p.palette(3)
349        if components is not None:
350            cs = components
351            if isinstance(components,int): cs = [components]
[526]352            if plotparms:
353                self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
[515]354            n = len(self.components)
355            self._p.palette(4)
356            for c in cs:
357                if 0 <= c < n:
358                    lab = self.fitfuncs[c]+str(c)
359                    self._p.set_line(label=lab)
360                    self._p.plot(self.x, self.fitter.evaluate(c), m)
361                elif c == -1:
362                    self._p.palette(3)
363                    self._p.set_line(label="Total Fit")
364                    self._p.plot(self.x, self.get_fit(), m)                   
365        else:
366            self._p.palette(3)
367            self._p.set_line(label='Fit')
368            self._p.plot(self.x, self.get_fit(), m)
[113]369        self._p.set_axes('xlabel',xlab)
370        self._p.set_axes('ylabel',ylab)
371        self._p.set_axes('title',tlab)
372        self._p.release()
373
[259]374    def auto_fit(self, insitu=None):
[113]375        """
[515]376        Return a scan where the function is applied to all rows for
377        all Beams/IFs/Pols.
[113]378       
379        """
380        from asap import scantable
[515]381        if not isinstance(self.data, scantable) :
[113]382            print "Only works with scantables"
383            return
[259]384        if insitu is None: insitu = rcParams['insitu']
385        if not insitu:
386            scan = self.data.copy()
387        else:
388            scan = self.data
389        vb = scan._vb
390        scan._vb = False
391        sel = scan.get_cursor()
[159]392        rows = range(scan.nrow())
[113]393        for i in range(scan.nbeam()):
394            scan.setbeam(i)
395            for j in range(scan.nif()):
396                scan.setif(j)
397                for k in range(scan.npol()):
398                    scan.setpol(k)
399                    if self._vb:
400                        print "Fitting:"
401                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
[159]402                    for iRow in rows:
[259]403                        self.x = scan._getabcissa(iRow)
404                        self.y = scan._getspectrum(iRow)
[159]405                        self.data = None
406                        self.fit()                   
[113]407                        x = self.get_parameters()
[259]408                        scan._setspectrum(self.fitter.getresidual(),iRow)
409        scan.set_cursor(sel[0],sel[1],sel[2])
410        scan._vb = vb
411        if not insitu:
412            return scan
Note: See TracBrowser for help on using the repository browser.