source: branches/Release-1-fixes/python/asapfitter.py@ 595

Last change on this file since 595 was 526, checked in by mar637, 20 years ago

added more help

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 13.9 KB
Line 
1import _asap
2from asap import rcParams
3
class fitter:
    """
    The fitting class for ASAP.

    Typical use: attach data with set_scan (a scantable) or set_data
    (plain abcissa/ordinate vectors), choose a model with set_function,
    optionally seed it with set_parameters / set_gauss_parameters, then
    call fit() and inspect the result with get_parameters, get_fit,
    get_residual, get_chi2 or plot.
    """

    def _verbose(self, *args):
        """
        Get or set stdout output.

        Called with a single bool, store it as the verbosity flag.
        Called with no arguments, return the current flag.
        Any other argument is ignored and None is returned.
        """
        # Test the argument count first: indexing args[0] on an empty
        # tuple would raise IndexError before the getter branch is reached.
        if len(args) == 0:
            return self._vb
        if isinstance(args[0], bool):
            self._vb = args[0]
        return

    def __init__(self):
        """
        Create a fitter object. No state is set.
        """
        self.fitter = _asap.fitter()  # backend fitting engine
        self.x = None
        self.y = None
        self.mask = None
        self.fitfunc = None       # 'poly' or 'gauss' once set_function ran
        self.fitfuncs = None      # function name for each component
        self.fitted = False
        self.data = None          # scantable attached via set_scan, if any
        self.components = 0
        self._fittedrow = 0       # row passed to the most recent fit()
        self._p = None            # plotter, created lazily by plot()
        self._vb = True           # print results to stdout
        self._selection = None    # scantable cursor recorded by fit()

    def set_data(self, xdat, ydat, mask=None):
        """
        Set the abcissa and ordinate for the fit. Also set the mask
        indicating valid points.
        This can be used for data vectors retrieved from a scantable.
        For scantable fitting use 'fitter.set_scan(scan, mask)'.
        Note: if a scantable is still attached via set_scan, fit() reads
        its data from the scantable rather than from these vectors.
        Parameters:
            xdat:    the abcissa values
            ydat:    the ordinate values
            mask:    an optional mask (default: all points valid)
        """
        self.fitted = False
        self.x = xdat
        self.y = ydat
        # 'is None' rather than '== None': an array mask would otherwise
        # be compared element by element.
        if mask is None:
            from numarray import ones
            self.mask = ones(len(xdat))
        else:
            self.mask = mask
        return

    def set_scan(self, thescan=None, mask=None):
        """
        Set the 'data' (a scantable) of the fitter.
        Parameters:
            thescan:     a scantable
            mask:        a mask retrieved from the scantable
                         (default: all channels valid)
        """
        if not thescan:
            print("Please give a correct scan")
            # Nothing to attach; continuing would call nchan() on None.
            return
        self.fitted = False
        self.data = thescan
        if mask is None:
            from numarray import ones
            self.mask = ones(self.data.nchan())
        else:
            self.mask = mask
        return

    def set_function(self, **kwargs):
        """
        Set the function to be fit.
        Parameters:
            poly:    use a polynomial of the order given
            gauss:   fit the number of gaussian specified
        Example:
            fitter.set_function(gauss=2) # will fit two gaussians
            fitter.set_function(poly=3)  # will fit a 3rd order polynomial
        Any previous fit result is invalidated.
        """
        if 'poly' in kwargs:
            n = kwargs.get('poly')
            self.fitfunc = 'poly'
            # A polynomial is a single component; give it a name so that
            # store_fit and plot see a valid per-component list.
            self.fitfuncs = ['poly']
            self.components = [n]
        elif 'gauss' in kwargs:
            n = kwargs.get('gauss')
            self.fitfunc = 'gauss'
            self.fitfuncs = ['gauss'] * n
            # three parameters (peak, centre, FWHM) per gaussian
            self.components = [3] * n
        else:
            print("Invalid function type.")
            return
        self.fitted = False
        self.fitter.setexpression(self.fitfunc, n)
        return

    def fit(self, row=0):
        """
        Execute the actual fitting process. All the state has to be set.
        If a scantable is attached (set_scan), the abcissa and spectrum
        of 'row' are read from it; otherwise the vectors given to
        set_data are used. For Gaussian fits with no parameters set yet,
        the backend's estimate() is called first to seed the fit.
        Parameters:
            row:        specify the row in the scantable
        Example:
            s = scantable('myscan.asap')
            s.set_cursor(thepol=1)        # select second pol
            f = fitter()
            f.set_scan(s)
            f.set_function(poly=0)
            f.fit(row=0)                  # fit first row
        """
        if ((self.x is None or self.y is None) and self.data is None) \
               or self.fitfunc is None:
            print("Fitter not yet initialised. Please set data & fit function")
            return
        if self.data is not None:
            self.x = self.data._getabcissa(row)
            self.y = self.data._getspectrum(row)
            print("Fitting:")
            # Temporarily force the scantable verbose (presumably so that
            # get_cursor() echoes the selection), then restore its setting.
            vb = self.data._vb
            self.data._vb = True
            self._selection = self.data.get_cursor()
            self.data._vb = vb
        self.fitter.setdata(self.x, self.y, self.mask)
        if self.fitfunc == 'gauss':
            if len(self.fitter.getparameters()) == 0:
                self.fitter.estimate()
        self.fitter.fit()
        self._fittedrow = row
        self.fitted = True
        return

    def store_fit(self):
        """
        Store the fit parameters in the scantable.
        Requires a completed fit and an attached scantable (set_scan);
        otherwise a message is printed and nothing is stored.
        """
        if not self.fitted:
            print("Not yet fitted.")
            return
        if self.data is None:
            print("Only works with scantables")
            return
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        self.data._addfit(self._fittedrow, pars, fixed,
                          self.fitfuncs, self.components)

    def set_parameters(self, params, fixed=None, component=None):
        """
        Set the parameters to be fitted.
        Parameters:
            params:    a vector of parameters
            fixed:     a vector of which parameters are to be held fixed
                       (default: leave the fixed flags unchanged)
            component: in case of multiple gaussians, the index of the
                       component
        """
        if self.fitfunc is None:
            print("Please specify a fitting function first.")
            return
        if self.fitfunc == "gauss" and component is not None:
            npars = len(self.components) * 3
            # Start from what the backend already holds so that values
            # set earlier for other components are preserved; fall back
            # to zeros / all-free when the backend's vectors do not match
            # the current number of components.
            pars = list(self.fitter.getparameters())
            if len(pars) != npars:
                pars = [0.0] * npars
            fxd = list(self.fitter.getfixedparameters())
            if len(fxd) != npars:
                fxd = [False] * npars
            i = 3 * component
            pars[i:i+3] = params
            if fixed is not None:
                fxd[i:i+3] = fixed
            params = pars
            fixed = fxd
        self.fitter.setparameters(params)
        if fixed is not None:
            self.fitter.setfixedparameters(fixed)
        return

    def set_gauss_parameters(self, peak, centre, fhwm,
                             peakfixed=False, centerfixed=False,
                             fhwmfixed=False,
                             component=0):
        """
        Set the Parameters of a 'Gaussian' component, set with set_function.
        Parameters:
            peak, centre, fhwm:  The gaussian parameters
            peakfixed,
            centerfixed,
            fhwmfixed:           Optional parameters to indicate if
                                 the parameters should be held fixed during
                                 the fitting process. The default is to keep
                                 all parameters flexible.
            component:           The number of the component (Default is the
                                 component 0)
        """
        if self.fitfunc != "gauss":
            print("Function only operates on Gaussian components.")
            return
        if 0 <= component < len(self.components):
            self.set_parameters([peak, centre, fhwm],
                                [peakfixed, centerfixed, fhwmfixed],
                                component)
        else:
            print("Please select a valid component.")
            return

    def get_parameters(self, component=None):
        """
        Return the fit parameters.
        Parameters:
            component:    get the parameters for the specified component
                          only, default is all components
                          (only honoured for Gaussian fits)
        Returns:
            (parameters, fixed flags, formatted string); the string is
            also printed when verbose.
        """
        if not self.fitted:
            print("Not yet fitted.")
        pars = list(self.fitter.getparameters())
        fixed = list(self.fitter.getfixedparameters())
        cpars = pars
        cfixed = fixed
        if component is not None and self.fitfunc == "gauss":
            i = 3 * component
            cpars = pars[i:i+3]
            cfixed = fixed[i:i+3]
        fpars = self._format_pars(cpars, cfixed)
        if self._vb:
            print(fpars)
        return cpars, cfixed, fpars

    def _format_pars(self, pars, fixed):
        """
        Format parameters as a human-readable string for the current
        fit function ('' for an unknown function).
        Parameters:
            pars:   the parameter values
            fixed:  per-parameter fixed flags; missing entries count as free
        """
        if self.fitfunc == 'poly':
            items = []
            for c, value in enumerate(pars):
                fix = ""
                if c < len(fixed) and fixed[c]:
                    fix = "(fixed)"
                items.append(' p%d%s= %3.3f' % (c, fix, value))
            return ','.join(items)
        if self.fitfunc == 'gauss':
            aunit = ''
            ounit = ''
            if self.data is not None:
                aunit = self.data.get_unit()
                ounit = self.data.get_fluxunit()
            lines = []
            # three parameters per component; a trailing partial group
            # is ignored instead of raising IndexError
            for c, i in enumerate(range(0, len(pars) - 2, 3)):
                lines.append(' %d: peak = %3.3f %s , centre = %3.3f %s, FWHM = %3.3f %s \n' % (c, pars[i], ounit, pars[i+1], aunit, pars[i+2], aunit))
            return ''.join(lines)
        return ''

    def get_estimate(self):
        """
        Return the parameter estimates (for non-linear functions).
        """
        pars = self.fitter.getestimate()
        if self._vb:
            # _format_pars needs fixed flags; none accompany the
            # estimates here, so every parameter is shown as free.
            print(self._format_pars(pars, [False] * len(pars)))
        return pars

    def get_residual(self):
        """
        Return the residual of the fit.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getresidual()

    def get_chi2(self):
        """
        Return chi^2.
        """
        if not self.fitted:
            print("Not yet fitted.")
        ch2 = self.fitter.getchi2()
        if self._vb:
            print('Chi^2 = %3.3f' % (ch2))
        return ch2

    def get_fit(self):
        """
        Return the fitted ordinate values.
        """
        if not self.fitted:
            print("Not yet fitted.")
        return self.fitter.getfit()

    def commit(self):
        """
        Return a new scan where the fit has been committed (subtracted).
        The residual of the last fit replaces the spectrum of the fitted
        row in a copy of the attached scantable; the original is untouched.
        Returns None (after printing a message) if there is no fit or no
        attached scantable.
        """
        from asap import scantable
        if not self.fitted:
            print("Not yet fitted.")
            return
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        scan = self.data.copy()
        scan._setspectrum(self.fitter.getresidual(), self._fittedrow)
        return scan

    def plot(self, residual=False, components=None, plotparms=False):
        """
        Plot the last fit.
        Parameters:
            residual:    an optional parameter indicating if the residual
                         should be plotted (default 'False')
            components:  a list of components to plot, e.g [0,1],
                         -1 plots the total fit. Default is to only
                         plot the total fit.
            plotparms:   Indicates if the parameter values should be present
                         on the plot
        """
        if not self.fitted:
            return
        # (Re)create the plotter if there is none yet or it was closed.
        if not self._p or self._p.is_dead:
            from asap.asaplot import ASAPlot
            self._p = ASAPlot()
        self._p.clear()
        self._p.set_panels()
        self._p.palette(1)
        tlab = 'Spectrum'
        xlab = 'Abcissa'
        m = ()
        if self.data is not None:
            tlab = self.data._getsourcename(self._fittedrow)
            xlab = self.data._getabcissalabel(self._fittedrow)
            m = self.data._getmask(self._fittedrow)
        ylab = r'Flux'

        colours = ["grey60","grey80","red","orange","purple","yellow","magenta", "cyan"]
        self._p.palette(1, colours)
        self._p.set_line(label='Spectrum')
        self._p.plot(self.x, self.y, m)
        if residual:
            self._p.palette(2)
            self._p.set_line(label='Residual')
            self._p.plot(self.x, self.get_residual(), m)
        self._p.palette(3)
        if components is not None:
            cs = components
            if isinstance(components, int):
                cs = [components]
            if plotparms:
                self._p.text(0.15, 0.15, str(self.get_parameters()[2]), size=8)
            n = len(self.components)
            self._p.palette(4)
            for c in cs:
                if 0 <= c < n:
                    lab = self.fitfuncs[c] + str(c)
                    self._p.set_line(label=lab)
                    self._p.plot(self.x, self.fitter.evaluate(c), m)
                elif c == -1:
                    self._p.palette(3)
                    self._p.set_line(label="Total Fit")
                    self._p.plot(self.x, self.get_fit(), m)
        else:
            self._p.palette(3)
            self._p.set_line(label='Fit')
            self._p.plot(self.x, self.get_fit(), m)
        self._p.set_axes('xlabel', xlab)
        self._p.set_axes('ylabel', ylab)
        self._p.set_axes('title', tlab)
        self._p.release()

    def auto_fit(self, insitu=None):
        """
        Return a scan where the function is applied to all rows for
        all Beams/IFs/Pols; each fitted spectrum is replaced by its residual.
        Parameters:
            insitu:    if True modify the attached scantable in place,
                       otherwise work on (and return) a copy.
                       Default is rcParams['insitu'].
        After the call the fitter is still attached to its scantable;
        its fit state is cleared, since the backend holds only the last
        row's fit.
        """
        from asap import scantable
        if not isinstance(self.data, scantable):
            print("Only works with scantables")
            return
        if insitu is None:
            insitu = rcParams['insitu']
        if not insitu:
            scan = self.data.copy()
        else:
            scan = self.data
        source = self.data
        vb = scan._vb
        scan._vb = False
        sel = scan.get_cursor()
        rows = range(scan.nrow())
        try:
            # fit() reads from self.data whenever it is set; detach it so
            # the per-row vectors placed in self.x / self.y are fitted.
            self.data = None
            for i in range(scan.nbeam()):
                scan.setbeam(i)
                for j in range(scan.nif()):
                    scan.setif(j)
                    for k in range(scan.npol()):
                        scan.setpol(k)
                        if self._vb:
                            print("Fitting:")
                            print('Beam[%d], IF[%d], Pol[%d]' % (i, j, k))
                        for iRow in rows:
                            self.x = scan._getabcissa(iRow)
                            self.y = scan._getspectrum(iRow)
                            self.fit()
                            # called for its verbose printout only
                            self.get_parameters()
                            scan._setspectrum(self.fitter.getresidual(), iRow)
        finally:
            # Restore the scan's selection and verbosity, and re-attach the
            # fitter's scantable, even if a fit raised.
            scan.set_cursor(sel[0], sel[1], sel[2])
            scan._vb = vb
            self.data = source
            # The backend holds the last row's fit, which does not match
            # self._fittedrow under the restored cursor; clear the flag so
            # store_fit/commit cannot write it to the wrong row.
            self.fitted = False
        if not insitu:
            return scan
Note: See TracBrowser for help on using the repository browser.