source: trunk/python/asapfitter.py @ 626

Last change on this file since 626 was 626, checked in by mar637, 19 years ago

fix for asap0019 from Release-1

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.0 KB
Line 
1import _asap
2from asap import rcParams
3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
27        self.fitfuncs = None
28        self.fitted = False
29        self.data = None
30        self.components = 0
31        self._fittedrow = 0
32        self._p = None
33        self._vb = True
34        self._selection = None
35
36    def set_data(self, xdat, ydat, mask=None):
37        """
38        Set the absissa and ordinate for the fit. Also set the mask
39        indicationg valid points.
40        This can be used for data vectors retrieved from a scantable.
41        For scantable fitting use 'fitter.set_scan(scan, mask)'.
42        Parameters:
43            xdat:    the abcissa values
44            ydat:    the ordinate values
45            mask:    an optional mask
46       
47        """
48        self.fitted = False
49        self.x = xdat
50        self.y = ydat
51        if mask == None:
52            from numarray import ones
53            self.mask = ones(len(xdat))
54        else:
55            self.mask = mask
56        return
57
58    def set_scan(self, thescan=None, mask=None):
59        """
60        Set the 'data' (a scantable) of the fitter.
61        Parameters:
62            thescan:     a scantable
63            mask:        a msk retireved from the scantable
64        """
65        if not thescan:
66            print "Please give a correct scan"
67        self.fitted = False
68        self.data = thescan
69        if mask is None:
70            from numarray import ones
71            self.mask = ones(self.data.nchan())
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given
81            gauss:   fit the number of gaussian specified
82        Example:
83            fitter.set_function(gauss=2) # will fit two gaussians
84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
85        """
86        #default poly order 0       
87        n=0
88        if kwargs.has_key('poly'):
89            self.fitfunc = 'poly'
90            n = kwargs.get('poly')
91            self.components = [n]
92        elif kwargs.has_key('gauss'):
93            n = kwargs.get('gauss')
94            self.fitfunc = 'gauss'
95            self.fitfuncs = [ 'gauss' for i in range(n) ]
96            self.components = [ 3 for i in range(n) ]
97        else:
98            print "Invalid function type."
99            return
100        self.fitter.setexpression(self.fitfunc,n)
101        return
102           
103    def fit(self, row=0):
104        """
105        Execute the actual fitting process. All the state has to be set.
106        Parameters:
107            row:    specify the row in the scantable
108        Example:
109            s = scantable('myscan.asap')
110            s.set_cursor(thepol=1)        # select second pol
111            f = fitter()
112            f.set_scan(s)
113            f.set_function(poly=0)
114            f.fit(row=0)                  # fit first row
115        """
116        if ((self.x is None or self.y is None) and self.data is None) \
117               or self.fitfunc is None:
118            print "Fitter not yet initialised. Please set data & fit function"
119            return
120        else:
121            if self.data is not None:
122                self.x = self.data._getabcissa(row)
123                self.y = self.data._getspectrum(row)
124                print "Fitting:"
125                vb = self.data._vb
126                self.data._vb = True
127                self.selection = self.data.get_cursor()
128                self.data._vb = vb
129        self.fitter.setdata(self.x, self.y, self.mask)
130        if self.fitfunc == 'gauss':
131            ps = self.fitter.getparameters()
132            if len(ps) == 0:
133                self.fitter.estimate()
134        try:
135            self.fitter.fit()
136        except RuntimeError, msg:
137            print msg           
138        self._fittedrow = row
139        self.fitted = True
140        return
141
142    def store_fit(self):
143        """
144        Store the fit parameters in the scantable.
145        """
146        if self.fitted and self.data is not None:
147            pars = list(self.fitter.getparameters())
148            fixed = list(self.fitter.getfixedparameters())
149            self.data._addfit(self._fittedrow, pars, fixed,
150                              self.fitfuncs, self.components)
151
152    def set_parameters(self, params, fixed=None, component=None):
153        """
154        Set the parameters to be fitted.
155        Parameters:
156              params:    a vector of parameters
157              fixed:     a vector of which parameters are to be held fixed
158                         (default is none)
159              component: in case of multiple gaussians, the index of the
160                         component
161             """
162        if self.fitfunc is None:
163            print "Please specify a fitting function first."
164            return
165        if self.fitfunc == "gauss" and component is not None:
166            if not self.fitted:
167                from numarray import zeros
168                pars = list(zeros(len(self.components)*3))
169                fxd = list(zeros(len(pars)))
170            else:
171                pars = list(self.fitter.getparameters())             
172                fxd = list(self.fitter.getfixedparameters())
173            i = 3*component
174            pars[i:i+3] = params
175            fxd[i:i+3] = fixed
176            params = pars
177            fixed = fxd         
178        self.fitter.setparameters(params)
179        if fixed is not None:
180            self.fitter.setfixedparameters(fixed)
181        return
182
183    def set_gauss_parameters(self, peak, centre, fhwm,
184                             peakfixed=False, centerfixed=False,
185                             fhwmfixed=False,
186                             component=0):
187        """
188        Set the Parameters of a 'Gaussian' component, set with set_function.
189        Parameters:
190            peak, centre, fhwm:  The gaussian parameters
191            peakfixed,
192            centerfixed,
193            fhwmfixed:           Optional parameters to indicate if
194                                 the paramters should be held fixed during
195                                 the fitting process. The default is to keep
196                                 all parameters flexible.
197            component:           The number of the component (Default is the
198                                 component 0)
199        """
200        if self.fitfunc != "gauss":
201            print "Function only operates on Gaussian components."
202            return
203        if 0 <= component < len(self.components):
204            self.set_parameters([peak, centre, fhwm],
205                                [peakfixed, centerfixed, fhwmfixed],
206                                component)
207        else:
208            print "Please select a valid  component."
209            return
210       
211    def get_parameters(self, component=None):
212        """
213        Return the fit paramters.
214        Parameters:
215             component:    get the parameters for the specified component
216                           only, default is all components
217        """
218        if not self.fitted:
219            print "Not yet fitted."
220        pars = list(self.fitter.getparameters())
221        fixed = list(self.fitter.getfixedparameters())
222        if component is not None:           
223            if self.fitfunc == "gauss":
224                i = 3*component
225                cpars = pars[i:i+3]
226                cfixed = fixed[i:i+3]
227            else:
228                cpars = pars
229                cfixed = fixed               
230        else:
231            cpars = pars
232            cfixed = fixed
233        fpars = self._format_pars(cpars, cfixed)
234        if self._vb:
235            print fpars
236        return cpars, cfixed, fpars
237   
238    def _format_pars(self, pars, fixed):
239        out = ''
240        if self.fitfunc == 'poly':
241            c = 0
242            for i in range(len(pars)):
243                fix = ""
244                if fixed[i]: fix = "(fixed)"
245                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
246                c+=1
247            out = out[:-1]  # remove trailing ','
248        elif self.fitfunc == 'gauss':
249            i = 0
250            c = 0
251            aunit = ''
252            ounit = ''
253            if self.data:
254                aunit = self.data.get_unit()
255                ounit = self.data.get_fluxunit()
256            while i < len(pars):
257                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
258                c+=1
259                i+=3
260        return out
261       
262    def get_estimate(self):
263        """
264        Return the parameter estimates (for non-linear functions).
265        """
266        pars = self.fitter.getestimate()
267        if self._vb:
268            print self._format_pars(pars)
269        return pars
270       
271
272    def get_residual(self):
273        """
274        Return the residual of the fit.
275        """
276        if not self.fitted:
277            print "Not yet fitted."
278        return self.fitter.getresidual()
279
280    def get_chi2(self):
281        """
282        Return chi^2.
283        """
284        if not self.fitted:
285            print "Not yet fitted."
286        ch2 = self.fitter.getchi2()
287        if self._vb:
288            print 'Chi^2 = %3.3f' % (ch2)
289        return ch2
290
291    def get_fit(self):
292        """
293        Return the fitted ordinate values.
294        """
295        if not self.fitted:
296            print "Not yet fitted."
297        return self.fitter.getfit()
298
299    def commit(self):
300        """
301        Return a new scan where the fits have been commited (subtracted)
302        """
303        if not self.fitted:
304            print "Not yet fitted."
305        if self.data is not scantable:
306            print "Only works with scantables"
307            return
308        scan = self.data.copy()
309        scan._setspectrum(self.fitter.getresidual())
310
311    def plot(self, residual=False, components=None, plotparms=False):
312        """
313        Plot the last fit.
314        Parameters:
315            residual:    an optional parameter indicating if the residual
316                         should be plotted (default 'False')
317            components:  a list of components to plot, e.g [0,1],
318                         -1 plots the total fit. Default is to only
319                         plot the total fit.
320            plotparms:   Inidicates if the parameter values should be present
321                         on the plot
322        """
323        if not self.fitted:
324            return
325        if not self._p:
326            from asap.asaplot import ASAPlot
327            self._p = ASAPlot()
328        if self._p.is_dead:
329            from asap.asaplot import ASAPlot
330            self._p = ASAPlot()
331        self._p.clear()
332        self._p.set_panels()
333        self._p.palette(1)
334        tlab = 'Spectrum'
335        xlab = 'Abcissa'       
336        m = ()
337        if self.data:
338            tlab = self.data._getsourcename(self._fittedrow)
339            xlab = self.data._getabcissalabel(self._fittedrow)
340            m = self.data._getmask(self._fittedrow)
341            ylab = self.data._get_ordinate_label()
342
343        colours = ["grey60","grey80","red","orange","purple","green","magenta", "cyan"]
344        self._p.palette(1,colours)
345        self._p.set_line(label='Spectrum')
346        self._p.plot(self.x, self.y, m)
347        if residual:
348            self._p.palette(2)
349            self._p.set_line(label='Residual')
350            self._p.plot(self.x, self.get_residual(), m)
351        self._p.palette(3)
352        if components is not None:
353            cs = components
354            if isinstance(components,int): cs = [components]
355            if plotparms:
356                self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
357            n = len(self.components)
358            self._p.palette(4)
359            for c in cs:
360                if 0 <= c < n:
361                    lab = self.fitfuncs[c]+str(c)
362                    self._p.set_line(label=lab)
363                    self._p.plot(self.x, self.fitter.evaluate(c), m)
364                elif c == -1:
365                    self._p.palette(3)
366                    self._p.set_line(label="Total Fit")
367                    self._p.plot(self.x, self.get_fit(), m)                   
368        else:
369            self._p.palette(3)
370            self._p.set_line(label='Fit')
371            self._p.plot(self.x, self.get_fit(), m)
372        self._p.set_axes('xlabel',xlab)
373        self._p.set_axes('ylabel',ylab)
374        self._p.set_axes('title',tlab)
375        self._p.release()
376
377    def auto_fit(self, insitu=None):
378        """
379        Return a scan where the function is applied to all rows for
380        all Beams/IFs/Pols.
381       
382        """
383        from asap import scantable
384        if not isinstance(self.data, scantable) :
385            print "Only works with scantables"
386            return
387        if insitu is None: insitu = rcParams['insitu']
388        if not insitu:
389            scan = self.data.copy()
390        else:
391            scan = self.data
392        vb = scan._vb
393        scan._vb = False
394        sel = scan.get_cursor()
395        rows = range(scan.nrow())
396        for i in range(scan.nbeam()):
397            scan.setbeam(i)
398            for j in range(scan.nif()):
399                scan.setif(j)
400                for k in range(scan.npol()):
401                    scan.setpol(k)
402                    if self._vb:
403                        print "Fitting:"
404                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
405                    for iRow in rows:
406                        self.x = scan._getabcissa(iRow)
407                        self.y = scan._getspectrum(iRow)
408                        self.data = None
409                        self.fit()                   
410                        x = self.get_parameters()
411                        scan._setspectrum(self.fitter.getresidual(),iRow)
412        scan.set_cursor(sel[0],sel[1],sel[2])
413        scan._vb = vb
414        if not insitu:
415            return scan
Note: See TracBrowser for help on using the repository browser.