source: trunk/python/asapfitter.py @ 302

Last change on this file since 302 was 302, checked in by kil064, 19 years ago

getmask -> _getmask

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 8.6 KB
Line 
1import _asap
2from asap import rcParams
3
4class fitter:
5    """
6    The fitting class for ASAP.
7    """
8    def _verbose(self, *args):
9        """
10        Set stdout output.
11        """
12        if type(args[0]) is bool:
13            self._vb = args[0]
14            return
15        elif len(args) == 0:
16            return self._vb
17       
18    def __init__(self):
19        """
20        Create a fitter object. No state is set.
21        """
22        self.fitter = _asap.fitter()
23        self.x = None
24        self.y = None
25        self.mask = None
26        self.fitfunc = None
27        self.fitted = False
28        self.data = None
29        self._p = None
30        self._vb = True
31
32    def set_data(self, xdat, ydat, mask=None):
33        """
34        Set the absissa and ordinate for the fit. Also set the mask
35        indicationg valid points.
36        This can be used for data vectors retrieved from a scantable.
37        For scantable fitting use 'fitter.set_scan(scan, mask)'.
38        Parameters:
39            xdat:    the abcissa values
40            ydat:    the ordinate values
41            mask:    an optional mask
42       
43        """
44        self.fitted = False
45        self.x = xdat
46        self.y = ydat
47        if mask == None:
48            from numarray import ones
49            self.mask = ones(len(xdat))
50        else:
51            self.mask = mask
52        return
53
54    def set_scan(self, thescan=None, mask=None):
55        """
56        Set the 'data' (a scantable) of the fitter.
57        Parameters:
58            thescan:     a scantable
59            mask:        a msk retireved from the scantable
60        """
61        if not thescan:
62            print "Please give a correct scan"
63        self.fitted = False
64        self.data = thescan
65        if mask is None:
66            from numarray import ones
67            self.mask = ones(self.data.nchan())
68        else:
69            self.mask = mask
70        return
71
72    def set_function(self, **kwargs):
73        """
74        Set the function to be fit.
75        Parameters:
76            poly:    use a polynomial of the order given
77            gauss:   fit the number of gaussian specified
78        Example:
79            fitter.set_function(gauss=2) # will fit two gaussians
80            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
81        """
82        #default poly order 0
83       
84
85        if kwargs.has_key('poly'):
86            self.fitfunc = 'poly'
87            n = kwargs.get('poly')
88        elif kwargs.has_key('gauss'):
89            n = kwargs.get('gauss')
90            self.fitfunc = 'gauss'
91       
92        self.fitter.setexpression(self.fitfunc,n)
93        return
94           
95    def fit(self):
96        """
97        Execute the actual fitting process. All the state has to be set.
98        Parameters:
99            none
100        Example:
101            s= scantable('myscan.asap')
102            f = fitter()
103            f.set_scan(s)
104            f.set_function(poly=0)
105            f.fit()
106        """
107        if ((self.x is None or self.y is None) and self.data is None) \
108               or self.fitfunc is None:
109            print "Fitter not yet initialised. Please set data & fit function"
110            return
111        else:
112            if self.data is not None:
113                self.x = self.data._getabcissa()
114                self.y = self.data._getspectrum()
115                print "Fitting:"
116                vb = self.data._vb
117                self.data._vb = True
118                s = self.data.get_cursor()
119                self.data._vb = vb
120       
121        self.fitter.setdata(self.x,self.y,self.mask)
122        if self.fitfunc == 'gauss':
123            ps = self.fitter.getparameters()
124            if len(ps) == 0:
125                self.fitter.estimate()
126        self.fitter.fit()
127        self.fitted = True
128        return
129
130    def set_parameters(self, params, fixed=None):
131        self.fitter.setparameters(params)
132        if fixed is not None:
133            self.fitter.setfixedparameters(fixed)
134        return
135   
136    def get_parameters(self):
137        """
138        Return the fit paramters.
139       
140        """
141        if not self.fitted:
142            print "Not yet fitted."
143        pars = list(self.fitter.getparameters())
144        fixed = list(self.fitter.getfixedparameters())
145        if self._vb:
146            print self._format_pars(pars)
147        return pars,fixed
148   
149    def _format_pars(self, pars):
150        out = ''
151        if self.fitfunc == 'poly':
152            c = 0
153            for i in pars:
154                out += '  p%d = %3.3f, ' % (c,i)
155                c+=1
156        elif self.fitfunc == 'gauss':
157            i = 0
158            c = 0
159            unit = ''
160            if self.data:
161                unit = self.data.get_unit()
162            while i < len(pars):
163                out += '  %d: peak = %3.3f , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c,pars[i],pars[i+1],unit,pars[i+2],unit)
164                c+=1
165                i+=3
166        return out
167       
168    def get_estimate(self):
169        """
170        Return the paramter estimates (for non-linear functions).
171        """
172        pars = self.fitter.getestimate()
173        if self._vb:
174            print self._format_pars(pars)
175        return pars
176       
177
178    def get_residual(self):
179        """
180        Return the residual of the fit.
181        """
182        if not self.fitted:
183            print "Not yet fitted."
184        return self.fitter.getresidual()
185
186    def get_chi2(self):
187        """
188        Return chi^2.
189        """
190       
191        if not self.fitted:
192            print "Not yet fitted."
193        ch2 = self.fitter.getchi2()
194        if self._vb:
195            print 'Chi^2 = %3.3f' % (ch2)
196        return ch2
197
198    def get_fit(self):
199        """
200        Return the fitted ordinate values.
201        """
202        if not self.fitted:
203            print "Not yet fitted."
204        return self.fitter.getfit()
205
206    def commit(self):
207        """
208        Return a new scan where teh fits have been commited.
209        """
210        if not self.fitted:
211            print "Not yet fitted."
212        if self.data is not scantable:
213            print "Only works with scantables"
214            return
215        scan = self.data.copy()
216        scan._setspectrum(self.fitter.getresidual())
217
218    def plot(self, residual=False):
219        """
220        Plot the last fit.
221        Parameters:
222            residual:    an optional parameter indicating if the residual
223                         should be plotted (default 'False')
224        """
225        if not self.fitted:
226            return
227        if not self._p:
228            from asap.asaplot import ASAPlot
229            self._p = ASAPlot()
230        if self._p.is_dead:
231            from asap.asaplot import ASAPlot
232            self._p = ASAPlot()
233        self._p.clear()
234        tlab = 'Spectrum'
235        xlab = 'Abcissa'
236        if self.data:
237            tlab = self.data._getsourcename(0)
238            xlab = self.data._getabcissalabel(0)
239        ylab = r'Flux'
240        m = self.data._getmask(0)
241        self._p.set_line(colour='blue',label='Spectrum')
242        self._p.plot(self.x, self.y, m)
243        if residual:
244            self._p.set_line(colour='green',label='Residual')
245            self._p.plot(self.x, self.get_residual(), m)
246        self._p.set_line(colour='red',label='Fit')
247        self._p.plot(self.x, self.get_fit(), m)
248       
249        self._p.set_axes('xlabel',xlab)
250        self._p.set_axes('ylabel',ylab)
251        self._p.set_axes('title',tlab)
252        self._p.release()
253
254
255    def auto_fit(self, insitu=None):
256        """
257        Return a scan where the function is applied to all rows for all Beams/IFs/Pols.
258       
259        """
260        from asap import scantable
261        if not isinstance(self.data,scantable) :
262            print "Only works with scantables"
263            return
264        if insitu is None: insitu = rcParams['insitu']
265        if not insitu:
266            scan = self.data.copy()
267        else:
268            scan = self.data
269        vb = scan._vb
270        scan._vb = False
271        sel = scan.get_cursor()
272        rows = range(scan.nrow())
273        for i in range(scan.nbeam()):
274            scan.setbeam(i)
275            for j in range(scan.nif()):
276                scan.setif(j)
277                for k in range(scan.npol()):
278                    scan.setpol(k)
279                    if self._vb:
280                        print "Fitting:"
281                        print 'Beam[%d], IF[%d], Pol[%d]' % (i,j,k)
282                    for iRow in rows:
283                        self.x = scan._getabcissa(iRow)
284                        self.y = scan._getspectrum(iRow)
285                        self.data = None
286                        self.fit()                   
287                        x = self.get_parameters()
288                        scan._setspectrum(self.fitter.getresidual(),iRow)
289        scan.set_cursor(sel[0],sel[1],sel[2])
290        scan._vb = vb
291        if not insitu:
292            return scan
Note: See TracBrowser for help on using the repository browser.