source: trunk/python/asapfitter.py @ 652

Last change on this file since 652 was 652, checked in by mar637, 19 years ago

removed color loading as mpl now supports named colors. some minor corrections on pol label handling. Also added orientation option for ps output.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.0 KB
Line 
1import _asap
2from asap import rcParams
3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
27        self.fitfuncs = None
28        self.fitted = False
29        self.data = None
30        self.components = 0
31        self._fittedrow = 0
32        self._p = None
33        self._vb = True
34        self._selection = None
35
36    def set_data(self, xdat, ydat, mask=None):
37        """
38        Set the absissa and ordinate for the fit. Also set the mask
39        indicationg valid points.
40        This can be used for data vectors retrieved from a scantable.
41        For scantable fitting use 'fitter.set_scan(scan, mask)'.
42        Parameters:
43            xdat:    the abcissa values
44            ydat:    the ordinate values
45            mask:    an optional mask
46       
47        """
48        self.fitted = False
49        self.x = xdat
50        self.y = ydat
51        if mask == None:
52            from numarray import ones
53            self.mask = ones(len(xdat))
54        else:
55            self.mask = mask
56        return
57
58    def set_scan(self, thescan=None, mask=None):
59        """
60        Set the 'data' (a scantable) of the fitter.
61        Parameters:
62            thescan:     a scantable
63            mask:        a msk retireved from the scantable
64        """
65        if not thescan:
66            print "Please give a correct scan"
67        self.fitted = False
68        self.data = thescan
69        if mask is None:
70            from numarray import ones
71            self.mask = ones(self.data.nchan())
72        else:
73            self.mask = mask
74        return
75
76    def set_function(self, **kwargs):
77        """
78        Set the function to be fit.
79        Parameters:
80            poly:    use a polynomial of the order given
81            gauss:   fit the number of gaussian specified
82        Example:
83            fitter.set_function(gauss=2) # will fit two gaussians
84            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
85        """
86        #default poly order 0       
87        n=0
88        if kwargs.has_key('poly'):
89            self.fitfunc = 'poly'
90            n = kwargs.get('poly')
91            self.components = [n]
92        elif kwargs.has_key('gauss'):
93            n = kwargs.get('gauss')
94            self.fitfunc = 'gauss'
95            self.fitfuncs = [ 'gauss' for i in range(n) ]
96            self.components = [ 3 for i in range(n) ]
97        else:
98            print "Invalid function type."
99            return
100        self.fitter.setexpression(self.fitfunc,n)
101        return
102           
103    def fit(self, row=0):
104        """
105        Execute the actual fitting process. All the state has to be set.
106        Parameters:
107            row:    specify the row in the scantable
108        Example:
109            s = scantable('myscan.asap')
110            s.set_cursor(thepol=1)        # select second pol
111            f = fitter()
112            f.set_scan(s)
113            f.set_function(poly=0)
114            f.fit(row=0)                  # fit first row
115        """
116        if ((self.x is None or self.y is None) and self.data is None) \
117               or self.fitfunc is None:
118            print "Fitter not yet initialised. Please set data & fit function"
119            return
120        else:
121            if self.data is not None:
122                self.x = self.data._getabcissa(row)
123                self.y = self.data._getspectrum(row)
124                print "Fitting:"
125                vb = self.data._vb
126                self.data._vb = True
127                self.selection = self.data.get_cursor()
128                self.data._vb = vb
129        self.fitter.setdata(self.x, self.y, self.mask)
130        if self.fitfunc == 'gauss':
131            ps = self.fitter.getparameters()
132            if len(ps) == 0:
133                self.fitter.estimate()
134        try:
135            self.fitter.fit()
136        except RuntimeError, msg:
137            print msg           
138        self._fittedrow = row
139        self.fitted = True
140        return
141
142    def store_fit(self):
143        """
144        Store the fit parameters in the scantable.
145        """
146        if self.fitted and self.data is not None:
147            pars = list(self.fitter.getparameters())
148            fixed = list(self.fitter.getfixedparameters())
149            self.data._addfit(self._fittedrow, pars, fixed,
150                              self.fitfuncs, self.components)
151
152    def set_parameters(self, params, fixed=None, component=None):
153        """
154        Set the parameters to be fitted.
155        Parameters:
156              params:    a vector of parameters
157              fixed:     a vector of which parameters are to be held fixed
158                         (default is none)
159              component: in case of multiple gaussians, the index of the
160                         component
161             """
162        if self.fitfunc is None:
163            print "Please specify a fitting function first."
164            return
165        if self.fitfunc == "gauss" and component is not None:
166            if not self.fitted:
167                from numarray import zeros
168                pars = list(zeros(len(self.components)*3))
169                fxd = list(zeros(len(pars)))
170            else:
171                pars = list(self.fitter.getparameters())             
172                fxd = list(self.fitter.getfixedparameters())
173            i = 3*component
174            pars[i:i+3] = params
175            fxd[i:i+3] = fixed
176            params = pars
177            fixed = fxd         
178        self.fitter.setparameters(params)
179        if fixed is not None:
180            self.fitter.setfixedparameters(fixed)
181        return
182
183    def set_gauss_parameters(self, peak, centre, fhwm,
184                             peakfixed=False, centerfixed=False,
185                             fhwmfixed=False,
186                             component=0):
187        """
188        Set the Parameters of a 'Gaussian' component, set with set_function.
189        Parameters:
190            peak, centre, fhwm:  The gaussian parameters
191            peakfixed,
192            centerfixed,
193            fhwmfixed:           Optional parameters to indicate if
194                                 the paramters should be held fixed during
195                                 the fitting process. The default is to keep
196                                 all parameters flexible.
197            component:           The number of the component (Default is the
198                                 component 0)
199        """
200        if self.fitfunc != "gauss":
201            print "Function only operates on Gaussian components."
202            return
203        if 0 <= component < len(self.components):
204            self.set_parameters([peak, centre, fhwm],
205                                [peakfixed, centerfixed, fhwmfixed],
206                                component)
207        else:
208            print "Please select a valid  component."
209            return
210       
211    def get_parameters(self, component=None):
212        """
213        Return the fit paramters.
214        Parameters:
215             component:    get the parameters for the specified component
216                           only, default is all components
217        """
218        if not self.fitted:
219            print "Not yet fitted."
220        pars = list(self.fitter.getparameters())
221        fixed = list(self.fitter.getfixedparameters())
222        if component is not None:           
223            if self.fitfunc == "gauss":
224                i = 3*component
225                cpars = pars[i:i+3]
226                cfixed = fixed[i:i+3]
227            else:
228                cpars = pars
229                cfixed = fixed               
230        else:
231            cpars = pars
232            cfixed = fixed
233        fpars = self._format_pars(cpars, cfixed)
234        if self._vb:
235            print fpars
236        return cpars, cfixed, fpars
237   
238    def _format_pars(self, pars, fixed):
239        out = ''
240        if self.fitfunc == 'poly':
241            c = 0
242            for i in range(len(pars)):
243                fix = ""
244                if fixed[i]: fix = "(fixed)"
245                out += '  p%d%s= %3.3f,' % (c,fix,pars[i])
246                c+=1
247            out = out[:-1]  # remove trailing ','
248        elif self.fitfunc == 'gauss':
249            i = 0
250            c = 0
251            aunit = ''
252            ounit = ''
253            if self.data:
254                aunit = self.data.get_unit()
255                ounit = self.data.get_fluxunit()
256            while i < len(pars):
257                out += '  %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],ounit,pars[i+1],aunit,pars[i+2],aunit)
258                c+=1
259                i+=3
260        return out
261       
262    def get_estimate(self):
263        """
264        Return the parameter estimates (for non-linear functions).
265        """
266        pars = self.fitter.getestimate()
267        if self._vb:
268            print self._format_pars(pars)
269        return pars
270       
271
272    def get_residual(self):
273        """
274        Return the residual of the fit.
275        """
276        if not self.fitted:
277            print "Not yet fitted."
278        return self.fitter.getresidual()
279
280    def get_chi2(self):
281        """
282        Return chi^2.
283        """
284        if not self.fitted:
285            print "Not yet fitted."
286        ch2 = self.fitter.getchi2()
287        if self._vb:
288            print 'Chi^2 = %3.3f' % (ch2)
289        return ch2
290
291    def get_fit(self):
292        """
293        Return the fitted ordinate values.
294        """
295        if not self.fitted:
296            print "Not yet fitted."
297        return self.fitter.getfit()
298
299    def commit(self):
300        """
301        Return a new scan where the fits have been commited (subtracted)
302        """
303        if not self.fitted:
304            print "Not yet fitted."
305        if self.data is not scantable:
306            print "Only works with scantables"
307            return
308        scan = self.data.copy()
309        scan._setspectrum(self.fitter.getresidual())
310
311    def plot(self, residual=False, components=None, plotparms=False):
312        """
313        Plot the last fit.
314        Parameters:
315            residual:    an optional parameter indicating if the residual
316                         should be plotted (default 'False')
317            components:  a list of components to plot, e.g [0,1],
318                         -1 plots the total fit. Default is to only
319                         plot the total fit.
320            plotparms:   Inidicates if the parameter values should be present
321                         on the plot
322        """
323        if not self.fitted:
324            return
325        if not self._p:
326            from asap.asaplot import ASAPlot
327            self._p = ASAPlot()
328        if self._p.is_dead:
329            from asap.asaplot import ASAPlot
330            self._p = ASAPlot()
331        self._p.clear()
332        self._p.set_panels()
333        self._p.palette(0)
334        tlab = 'Spectrum'
335        xlab = 'Abcissa'       
336        m = ()
337        if self.data:
338            tlab = self.data._getsourcename(self._fittedrow)
339            xlab = self.data._getabcissalabel(self._fittedrow)
340            m = self.data._getmask(self._fittedrow)
341            ylab = self.data._get_ordinate_label()
342
343        colours = ["grey60","grey80","red","orange","purple","green","magenta", "cyan"]
344        self._p.palette(0,colours)
345        self._p.set_line(label='Spectrum')
346        self._p.plot(self.x, self.y, m)
347        if residual:
348            self._p.palette(1)
349            self._p.set_line(label='Residual')
350            self._p.plot(self.x, self.get_residual(), m)
351        self._p.palette(2)
352        if components is not None:
353            cs = components
354            if isinstance(components,int): cs = [components]
355            if plotparms:
356                self._p.text(0.15,0.15,str(self.get_parameters()[2]),size=8)
357            n = len(self.components)
358            self._p.palette(3)
359            for c in cs:
360                if 0 <= c < n:
361                    lab = self.fitfuncs[c]+str(c)
362                    self._p.set_line(label=lab)
363                    self._p.plot(self.x, self.fitter.evaluate(c), m)
364                elif c == -1:
365                    self._p.palette(2)
366                    self._p.set_line(label="Total Fit")
367                    self._p.plot(self.x, self.get_fit(), m)                   
368        else:
369            self._p.palette(2)
370            self._p.set_line(label='Fit')
371            self._p.plot(self.x, self.get_fit(), m)
372        self._p.set_axes('xlabel',xlab)
373        self._p.set_axes('ylabel',ylab)
374        self._p.set_axes('title',tlab)
375        self._p.release()
376
377    def auto_fit(self, insitu=None):
378        """
379        Return a scan where the function is applied to all rows for
380        all Beams/IFs/Pols.
381       
382        """
383        from asap import scantable
384        if not isinstance(self.data, scantable) :
385            print "Only works with scantables"
386            return
387        if insitu is None: insitu = rcParams['insitu']
388        if not insitu:
389            scan = self.data.copy()
390        else:
391            scan = self.data
392        vb = scan._vb
393        scan._vb = False
394        sel = scan.get_cursor()
395        rows = range(scan.nrow())
396        for i in range(scan.nbeam()):
397            scan.setbeam(i)
398            for j in range(scan.nif()):
399                scan.setif(j)
400                for k in range(scan.npol()):
401                    scan.setpol(k)
402                    if self._vb:
403                        print "Fitting:"
404                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
405                    for iRow in rows:
406                        self.x = scan._getabcissa(iRow)
407                        self.y = scan._getspectrum(iRow)
408                        self.data = None
409                        self.fit()                   
410                        x = self.get_parameters()
411                        scan._setspectrum(self.fitter.getresidual(),iRow)
412        scan.set_cursor(sel[0],sel[1],sel[2])
413        scan._vb = vb
414        if not insitu:
415            return scan
Note: See TracBrowser for help on using the repository browser.