source: trunk/python/asapfitter.py @ 526

Last change on this file since 526 was 526, checked in by mar637, 19 years ago

added more help

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.9 KB
Line 
1import _asap
2from asap import rcParams
3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
27        self.fitfuncs = None
28        self.fitted = False
29        self.data = None
30        self.components = 0
31        self._fittedrow = 0
32        self._p = None
33        self._vb = True
34        self._selection = None
35
36    def set_data(self, xdat, ydat, mask=None):
37        """
38        Set the absissa and ordinate for the fit. Also set the mask
39        indicationg valid points.
40        This can be used for data vectors retrieved from a scantable.
41        For scantable fitting use 'fitter.set_scan(scan, mask)'.
42        Parameters:
43            xdat:    the abcissa values
44            ydat:    the ordinate values
45            mask:    an optional mask
46       
47        """
48        self.fitted = False
49        self.x = xdat
50        self.y = ydat
51        if mask == None:
52            from numarray import ones
53            self.mask = ones(len(xdat))
54        else:
55            self.mask = mask
56        return
57
58    def set_scan(self, thescan=None, mask=None):
59        """
60        Set the 'data' (a scantable) of the fitter.
61        Parameters:
62            thescan:     a scantable
63            mask:        a msk retireved from the scantable
64        """
65        if not thescan:
66            print "Please give a correct scan"
67        self.fitted = False
68        self.data = thescan
69        if mask is None:
70            from numarray import ones
71            self.mask = ones(self.data.nchan())
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given
81            gauss:   fit the number of gaussian specified
82        Example:
83            fitter.set_function(gauss=2) # will fit two gaussians
84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
85        """
86        #default poly order 0       
87        n=0
88        if kwargs.has_key('poly'):
89            self.fitfunc = 'poly'
90            n = kwargs.get('poly')
91            self.components = [n]
92        elif kwargs.has_key('gauss'):
93            n = kwargs.get('gauss')
94            self.fitfunc = 'gauss'
95            self.fitfuncs = [ 'gauss' for i in range(n) ]
96            self.components = [ 3 for i in range(n) ]
97        else:
98            print "Invalid function type."
99            return
100        self.fitter.setexpression(self.fitfunc,n)
101        return
102           
103    def fit(self, row=0):
104        """
105        Execute the actual fitting process. All the state has to be set.
106        Parameters:
107            row:    specify the row in the scantable
108        Example:
109            s = scantable('myscan.asap')
110            s.set_cursor(thepol=1)        # select second pol
111            f = fitter()
112            f.set_scan(s)
113            f.set_function(poly=0)
114            f.fit(row=0)                  # fit first row
115        """
116        if ((self.x is None or self.y is None) and self.data is None) \
117               or self.fitfunc is None:
118            print "Fitter not yet initialised. Please set data & fit function"
119            return
120        else:
121            if self.data is not None:
122                self.x = self.data._getabcissa(row)
123                self.y = self.data._getspectrum(row)
124                print "Fitting:"
125                vb = self.data._vb
126                self.data._vb = True
127                self.selection = self.data.get_cursor()
128                self.data._vb = vb
129        self.fitter.setdata(self.x, self.y, self.mask)
130        if self.fitfunc == 'gauss':
131            ps = self.fitter.getparameters()
132            if len(ps) == 0:
133                self.fitter.estimate()
134        self.fitter.fit()
135        self._fittedrow = row
136        self.fitted = True
137        return
138
139    def store_fit(self):
140        """
141        Store the fit parameters in the scantable.
142        """
143        if self.fitted and self.data is not None:
144            pars = list(self.fitter.getparameters())
145            fixed = list(self.fitter.getfixedparameters())
146            self.data._addfit(self._fittedrow, pars, fixed,
147                              self.fitfuncs, self.components)
148
149    def set_parameters(self, params, fixed=None, component=None):
150        """
151        Set the parameters to be fitted.
152        Parameters:
153              params:    a vector of parameters
154              fixed:     a vector of which parameters are to be held fixed
155                         (default is none)
156              component: in case of multiple gaussians, the index of the
157                         component
158             """
159        if self.fitfunc is None:
160            print "Please specify a fitting function first."
161            return
162        if self.fitfunc == "gauss" and component is not None:
163            if not self.fitted:
164                from numarray import zeros
165                pars = list(zeros(len(self.components)*3))
166                fxd = list(zeros(len(pars)))
167            else:
168                pars = list(self.fitter.getparameters())             
169                fxd = list(self.fitter.getfixedparameters())
170            i = 3*component
171            pars[i:i+3] = params
172            fxd[i:i+3] = fixed
173            params = pars
174            fixed = fxd         
175        self.fitter.setparameters(params)
176        if fixed is not None:
177            self.fitter.setfixedparameters(fixed)
178        return
179
180    def set_gauss_parameters(self, peak, centre, fhwm,
181                             peakfixed=False, centerfixed=False,
182                             fhwmfixed=False,
183                             component=0):
184        """
185        Set the Parameters of a 'Gaussian' component, set with set_function.
186        Parameters:
187            peak, centre, fhwm:  The gaussian parameters
188            peakfixed,
189            centerfixed,
190            fhwmfixed:           Optional parameters to indicate if
191                                 the paramters should be held fixed during
192                                 the fitting process. The default is to keep
193                                 all parameters flexible.
194            component:           The number of the component (Default is the
195                                 component 0)
196        """
197        if self.fitfunc != "gauss":
198            print "Function only operates on Gaussian components."
199            return
200        if 0 <= component < len(self.components):
201            self.set_parameters([peak, centre, fhwm],
202                                [peakfixed, centerfixed, fhwmfixed],
203                                component)
204        else:
205            print "Please select a valid  component."
206            return
207       
208    def get_parameters(self, component=None):
209        """
210        Return the fit paramters.
211        Parameters:
212             component:    get the parameters for the specified component
213                           only, default is all components
214        """
215        if not self.fitted:
216            print "Not yet fitted."
217        pars = list(self.fitter.getparameters())
218        fixed = list(self.fitter.getfixedparameters())
219        if component is not None:           
220            if self.fitfunc == "gauss":
221                i = 3*component
222                cpars = pars[i:i+3]
223                cfixed = fixed[i:i+3]
224            else:
225                cpars = pars
226                cfixed = fixed               
227        else:
228            cpars = pars
229            cfixed = fixed
230        fpars = self._format_pars(cpars, cfixed)
231        if self._vb:
232            print fpars
233        return cpars, cfixed, fpars
234   
235    def _format_pars(self, pars, fixed):
236        out = ''
237        if self.fitfunc == 'poly':
238            c = 0
239            for i in range(len(pars)):
240                fix = ""
241                if fixed[i]: fix = "(fixed)"
242                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
243                c+=1
244            out = out[:-1]  # remove trailing ','
245        elif self.fitfunc == 'gauss':
246            i = 0
247            c = 0
248            aunit = ''
249            ounit = ''
250            if self.data:
251                aunit = self.data.get_unit()
252                ounit = self.data.get_fluxunit()
253            while i < len(pars):
254                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
255                c+=1
256                i+=3
257        return out
258       
259    def get_estimate(self):
260        """
261        Return the parameter estimates (for non-linear functions).
262        """
263        pars = self.fitter.getestimate()
264        if self._vb:
265            print self._format_pars(pars)
266        return pars
267       
268
269    def get_residual(self):
270        """
271        Return the residual of the fit.
272        """
273        if not self.fitted:
274            print "Not yet fitted."
275        return self.fitter.getresidual()
276
277    def get_chi2(self):
278        """
279        Return chi^2.
280        """
281        if not self.fitted:
282            print "Not yet fitted."
283        ch2 = self.fitter.getchi2()
284        if self._vb:
285            print 'Chi^2 = %3.3f' % (ch2)
286        return ch2
287
288    def get_fit(self):
289        """
290        Return the fitted ordinate values.
291        """
292        if not self.fitted:
293            print "Not yet fitted."
294        return self.fitter.getfit()
295
296    def commit(self):
297        """
298        Return a new scan where the fits have been commited (subtracted)
299        """
300        if not self.fitted:
301            print "Not yet fitted."
302        if self.data is not scantable:
303            print "Only works with scantables"
304            return
305        scan = self.data.copy()
306        scan._setspectrum(self.fitter.getresidual())
307
308    def plot(self, residual=False, components=None, plotparms=False):
309        """
310        Plot the last fit.
311        Parameters:
312            residual:    an optional parameter indicating if the residual
313                         should be plotted (default 'False')
314            components:  a list of components to plot, e.g [0,1],
315                         -1 plots the total fit. Default is to only
316                         plot the total fit.
317            plotparms:   Inidicates if the parameter values should be present
318                         on the plot
319        """
320        if not self.fitted:
321            return
322        if not self._p:
323            from asap.asaplot import ASAPlot
324            self._p = ASAPlot()
325        if self._p.is_dead:
326            from asap.asaplot import ASAPlot
327            self._p = ASAPlot()
328        self._p.clear()
329        self._p.set_panels()
330        self._p.palette(1)
331        tlab = 'Spectrum'
332        xlab = 'Abcissa'       
333        m = ()
334        if self.data:
335            tlab = self.data._getsourcename(self._fittedrow)
336            xlab = self.data._getabcissalabel(self._fittedrow)
337            m = self.data._getmask(self._fittedrow)
338            ylab = r'Flux'
339
340        colours = ["grey60","grey80","red","orange","purple","yellow","magenta", "cyan"]
341        self._p.palette(1,colours)
342        self._p.set_line(label='Spectrum')
343        self._p.plot(self.x, self.y, m)
344        if residual:
345            self._p.palette(2)
346            self._p.set_line(label='Residual')
347            self._p.plot(self.x, self.get_residual(), m)
348        self._p.palette(3)
349        if components is not None:
350            cs = components
351            if isinstance(components,int): cs = [components]
352            if plotparms:
353                self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
354            n = len(self.components)
355            self._p.palette(4)
356            for c in cs:
357                if 0 <= c < n:
358                    lab = self.fitfuncs[c]+str(c)
359                    self._p.set_line(label=lab)
360                    self._p.plot(self.x, self.fitter.evaluate(c), m)
361                elif c == -1:
362                    self._p.palette(3)
363                    self._p.set_line(label="Total Fit")
364                    self._p.plot(self.x, self.get_fit(), m)                   
365        else:
366            self._p.palette(3)
367            self._p.set_line(label='Fit')
368            self._p.plot(self.x, self.get_fit(), m)
369        self._p.set_axes('xlabel',xlab)
370        self._p.set_axes('ylabel',ylab)
371        self._p.set_axes('title',tlab)
372        self._p.release()
373
374    def auto_fit(self, insitu=None):
375        """
376        Return a scan where the function is applied to all rows for
377        all Beams/IFs/Pols.
378       
379        """
380        from asap import scantable
381        if not isinstance(self.data, scantable) :
382            print "Only works with scantables"
383            return
384        if insitu is None: insitu = rcParams['insitu']
385        if not insitu:
386            scan = self.data.copy()
387        else:
388            scan = self.data
389        vb = scan._vb
390        scan._vb = False
391        sel = scan.get_cursor()
392        rows = range(scan.nrow())
393        for i in range(scan.nbeam()):
394            scan.setbeam(i)
395            for j in range(scan.nif()):
396                scan.setif(j)
397                for k in range(scan.npol()):
398                    scan.setpol(k)
399                    if self._vb:
400                        print "Fitting:"
401                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
402                    for iRow in rows:
403                        self.x = scan._getabcissa(iRow)
404                        self.y = scan._getspectrum(iRow)
405                        self.data = None
406                        self.fit()                   
407                        x = self.get_parameters()
408                        scan._setspectrum(self.fitter.getresidual(),iRow)
409        scan.set_cursor(sel[0],sel[1],sel[2])
410        scan._vb = vb
411        if not insitu:
412            return scan
Note: See TracBrowser for help on using the repository browser.