source: trunk/python/asapfitter.py@ 520

Last change on this file since 520 was 515, checked in by mar637, 20 years ago
  • major rework on plotting.
  • added component selection and plotting
  • added wrapper function to set parameters
  • added formatting of parameter print out
  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.0 KB
Line 
1import _asap
2from asap import rcParams
3
class fitter:
    """
    The fitting class for ASAP.

    Wraps the low-level '_asap.fitter' object and keeps track of the data
    to fit (plain vectors or a scantable), the channel mask, the selected
    fit function and the state of the last fit.

    Messages are written with the single-argument 'print(...)' form, which
    prints exactly the same text as the equivalent print statement.
    """

    def _verbose(self, *args):
        """
        Get or set stdout output.
        Parameters:
            args:    without an argument the current setting is returned;
                     with a single bool the setting is changed.
                     Anything else is ignored.
        """
        # The empty case has to be tested first: indexing an empty
        # argument tuple would raise an IndexError.
        if len(args) == 0:
            return self._vb
        if isinstance(args[0], bool):
            self._vb = args[0]
        return

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None
        self.fitfuncs = None
        self.fitted = False
        self.data = None
        self.components = 0
        self._fittedrow = 0
        self._p = None
        self._vb = True
        self._selection = None

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abscissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Parameters:
            xdat:    the abscissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points are used)
        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # Identity test: '== None' on an array mask is evaluated per
        # element and cannot be used as a condition.
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels are used)
        """
        if not thescan:
            print("Please give a correct scan")
            # Leave the fitter untouched; without a scan there is nothing
            # to derive the default mask from.
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        """
        if 'poly' in kwargs:
            n = kwargs.get('poly')
            self.fitfunc = 'poly'
            # Keep the per-component function list in step with
            # 'components'; it is handed on by store_fit() and plot().
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss' for i in range(n)]
            # every gaussian has three parameters
            self.components = [3 for i in range(n)]
        else:
            print("Invalid function type.")
            return
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        Parameters:
            row:    the row of the scantable to fit (default 0). Only
                    used to pick the data when a scantable has been set.
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        nodata = self.data is None and (self.x is None or self.y is None)
        if nodata or self.fitfunc is None:
            print("Fitter not yet initialised. Please set data & fit function")
            return
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            print("Fitting:")
            # Switch the scantable to verbose while the cursor is read and
            # always put the previous setting back.
            vb = self.data._vb
            self.data._vb = True
            try:
                self._selection = self.data.get_cursor()
            finally:
                self.data._vb = vb
            # former (public) name of the attribute, kept for old callers
            self.selection = self._selection
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            # only estimate when no start parameters have been set
            ps = self.fitter.getparameters()
            if len(ps) == 0:
                self.fitter.estimate()
        self.fitter.fit()
        self._fittedrow = row
        self.fitted = True
        return

    def store_fit(self):
        """
        Attach the parameters of the last fit to the fitted row of the
        scantable. Does nothing if no fit has been made or if the fitter
        works on plain data vectors.
        """
        if self.fitted and self.data is not None:
            pars = list(self.fitter.getparameters())
            fixed = list(self.fitter.getfixedparameters())
            self.data._addfit(self._fittedrow, pars, fixed,
                              self.fitfuncs, self.components)

    def set_parameters(self, params, fixed=None, component=None):
        """
        Set the parameters to be used in the fit.
        Parameters:
            params:       the parameter values. With 'component' given
                          (gaussians only) these are the three values of
                          that component, otherwise the full vector.
            fixed:        optional flags, one per entry in 'params',
                          indicating parameters to be held fixed.
                          If omitted the current flags are kept.
            component:    the number of the gaussian component to set
                          (default is to set all parameters at once)
        """
        if self.fitfunc is None:
            print("Please specify a fitting function first.")
            return
        if self.fitfunc == "gauss" and component is not None:
            ncomp = len(self.components)
            if not 0 <= component < ncomp:
                print("Please select a valid component.")
                return
            params = list(params)
            if len(params) != 3 or (fixed is not None and len(fixed) != 3):
                # a slice assignment of another length would silently
                # shift the parameters of the following components
                print("A gaussian component needs exactly three parameters.")
                return
            # Start from what the fitter already holds so that components
            # set earlier are preserved; fall back to zeros / not fixed
            # if nothing suitable has been set yet.
            pars = list(self.fitter.getparameters())
            fxd = list(self.fitter.getfixedparameters())
            if len(pars) != 3 * ncomp:
                pars = [0] * (3 * ncomp)
            if len(fxd) != 3 * ncomp:
                fxd = [False] * (3 * ncomp)
            i = 3 * component
            pars[i:i + 3] = params
            if fixed is not None:
                fxd[i:i + 3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=False, centerfixed=False,
                             fhwmfixed=False,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            component:           The number of the component (Default is
                                 the first component).
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed
                                 during the fitting process. The default
                                 is to keep all parameters flexible.
        """
        if self.fitfunc != "gauss":
            print("Function only operates on Gaussian components.")
            return
        if 0 <= component < len(self.components):
            self.set_parameters([peak, centre, fhwm],
                                [peakfixed, centerfixed, fhwmfixed],
                                component)
        else:
            print("Please select a valid component.")
        return

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
            component:    the number of the gaussian component to return
                          (default is all parameters). Ignored for
                          polynomials.
        Returns:
            a tuple (parameters, fixed flags, formatted string).
            The formatted string is also printed if output is enabled.
        """
        if not self.fitted:
            # only a warning; whatever the fitter holds is still returned
            print("Not yet fitted.")
        cpars = list(self.fitter.getparameters())
        cfixed = list(self.fitter.getfixedparameters())
        if component is not None and self.fitfunc == "gauss":
            i = 3 * component
            cpars = cpars[i:i + 3]
            cfixed = cfixed[i:i + 3]
        fpars = self._format_pars(cpars, cfixed)
        if self._vb:
            print(fpars)
        return cpars, cfixed, fpars

    def _format_pars(self, pars, fixed=None):
        """
        Return the parameters formatted as a string.
        Parameters:
            pars:     the parameter values
            fixed:    the matching 'fixed' flags (default: none fixed).
                      Only shown for polynomials.
        """
        if fixed is None:
            fixed = [False] * len(pars)
        out = ''
        if self.fitfunc == 'poly':
            terms = []
            for c in range(len(pars)):
                fix = ""
                if fixed[c]:
                    fix = "(fixed)"
                terms.append(' p%d%s= %3.3f' % (c, fix, pars[c]))
            out = ','.join(terms)
        elif self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            fmt = ' %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n'
            lines = []
            # three consecutive values (peak, centre, FWHM) per component
            for c, i in enumerate(range(0, len(pars), 3)):
                lines.append(fmt % (c, pars[i], ounit, pars[i + 1], aunit,
                                    pars[i + 2], aunit))
            out = ''.join(lines)
        return out

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        The formatted estimates are also printed if output is enabled.
        """
        pars = self.fitter.getestimate()
        if self._vb:
            print(self._format_pars(pars))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2. The value is also printed if output is enabled.
        """
        if not self.fitted:
            print("Not yet fitted.")
        ch2 = self.fitter.getchi2()
        if self._vb:
            print('Chi^2 = %3.3f' % ch2)
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fits have been committed, i.e. the
        spectrum of the fitted row is replaced by the fit residual.
        Returns nothing if no fit has been made or if the fitter does not
        work on a scantable.
        """
        if not self.fitted:
            print("Not yet fitted.")
            return
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        scan = self.data.copy()
        # write to the row that was actually fitted
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        return scan

    def plot(self, residual=False, components=None, plotparms=False,
             plotrange=None):
        """
        Plot the last fit. Does nothing if no fit has been made.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a component number or a list of component
                         numbers to plot individually; -1 selects the
                         total fit (default is to plot the total fit)
            plotparms:   also write the fit parameters into the plot.
                         They are always written when 'components' is
                         given (default 'False')
            plotrange:   currently not used
        """
        if not self.fitted:
            return
        # (re)create the plotter if there is none or its window is gone
        if not self._p or self._p.is_dead:
            from asap.asaplot import ASAPlot
            self._p = ASAPlot()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(1)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        m = ()
        if self.data:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
        ylab = r'Flux'

        colours = ["grey60", "grey80", "red", "orange", "purple", "yellow",
                   "magenta", "cyan"]
        self._p.palette(1, colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(2)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(3)
        if components is not None or plotparms:
            self._p.text(0.15, 0.15, str(self.get_parameters()[2]), size=8)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            n = len(self.components)
            # fall back to the overall function name if no per-component
            # names have been recorded
            funcs = self.fitfuncs or [self.fitfunc] * n
            self._p.palette(4)
            for c in cs:
                if 0 <= c < n:
                    self._p.set_line(label=funcs[c] + str(c))
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(3)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(3)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols. Every spectrum is replaced by the residual
        of its fit.
        Parameters:
            insitu:    if False a new scantable is returned, otherwise
                       the fitter's scantable is modified and nothing is
                       returned. The default is taken from
                       rcParams['insitu'].
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        if insitu is None:
            insitu = rcParams['insitu']
        if insitu:
            scan = self.data
        else:
            scan = self.data.copy()
        thedata = self.data
        vb = scan._vb
        scan._vb = False
        sel = scan.get_cursor()
        rows = range(scan.nrow())
        try:
            # fit() only uses self.x/self.y while no scantable is set,
            # so detach it for the duration of the loop.
            self.data = None
            for i in range(scan.nbeam()):
                scan.setbeam(i)
                for j in range(scan.nif()):
                    scan.setif(j)
                    for k in range(scan.npol()):
                        scan.setpol(k)
                        if self._vb:
                            print("Fitting:")
                            print('Beam[%d], IF[%d], Pol[%d]' % (i, j, k))
                        for iRow in rows:
                            self.x = scan._getabcissa(iRow)
                            self.y = scan._getspectrum(iRow)
                            self.fit()
                            # prints the parameters if output is enabled
                            self.get_parameters()
                            scan._setspectrum(self.fitter.getresidual(),
                                              iRow)
        finally:
            # put back the fitter's scantable and the state of the scan,
            # also when a fit fails half way through
            self.data = thedata
            scan.set_cursor(sel[0], sel[1], sel[2])
            scan._vb = vb
        if not insitu:
            return scan
Note: See TracBrowser for help on using the repository browser.