source: tags/asap2alpha/python/asaplotbase.py

Last change on this file was 710, checked in by mar637, 19 years ago

create_mask now also handles args[0]=list. auto_quotient checks for conformance between # of ons and offs

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 19.0 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15
16class asaplotbase:
17    """
18    ASAP plotting base class based on matplotlib.
19    """
20
21    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
22        """
23        Create a new instance of the ASAPlot plotting class.
24
25        If rows < 1 then a separate call to set_panels() is required to define
26        the panel layout; refer to the doctext for set_panels().
27        """
28        self.is_dead = False
29        self.figure = Figure(figsize=size, facecolor='#ddddee')
30        self.canvas = None
31
32        self.set_title(title)
33        self.subplots = []
34        if rows > 0:
35            self.set_panels(rows, cols)
36
37        # Set matplotlib default colour sequence.
38        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
39       
40        c = asaprcParams['plotter.colours']
41        if isinstance(c,str) and len(c) > 0:
42            self.colormap = c.split()
43
44        self.lsalias = {"line":  [1,0],
45                        "dashdot": [4,2,1,2],
46                        "dashed" : [4,2,4,2],
47                        "dotted" : [1,2],
48                        "dashdotdot": [4,2,1,2,1,2],
49                        "dashdashdot": [4,2,4,2,1,2]
50                        }
51
52        styles = "line dashed dotted dashdot".split()
53        c = asaprcParams['plotter.linestyles']
54        if isinstance(c,str) and len(c) > 0:
55            styles = c.split()
56        s = []
57        for ls in styles:
58            if self.lsalias.has_key(ls):
59                s.append(self.lsalias.get(ls))
60            else:
61                s.append('-')
62        self.linestyles = s
63
64        self.color = 0;
65        self.linestyle = 0;
66        self.attributes = {}
67        self.loc = 0
68
69        self.buffering = buffering
70
71    def clear(self):
72        """
73        Delete all lines from the plot.  Line numbering will restart from 1.
74        """
75
76        for i in range(len(self.lines)):
77           self.delete(i)
78        self.axes.clear()
79        self.color = 0
80        self.lines = []
81
82    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
83        if colormap:
84            if isinstance(colormap,list):
85                self.colormap = colormap
86            elif isinstance(colormap,str):
87                self.colormap = colormap.split()
88        if 0 <= color < len(self.colormap):
89            self.color = color
90        if linestyles:
91            self.linestyles = []
92            if isinstance(linestyles,list):
93                styles = linestyles
94            elif isinstance(linestyles,str):
95                styles = linestyles.split()
96            for ls in styles:
97                if self.lsalias.has_key(ls):
98                    self.linestyles.append(self.lsalias.get(ls))
99                else:
100                    self.linestyles.append(self.lsalias.get('line'))
101        if 0 <= linestyle < len(self.linestyles):
102            self.linestyle = linestyle
103
104    def delete(self, numbers=None):
105        """
106        Delete the 0-relative line number, default is to delete the last.
107        The remaining lines are NOT renumbered.
108        """
109
110        if numbers is None: numbers = [len(self.lines)-1]
111
112        if not hasattr(numbers, '__iter__'):
113            numbers = [numbers]
114
115        for number in numbers:
116            if 0 <= number < len(self.lines):
117                if self.lines[number] is not None:
118                    for line in self.lines[number]:
119                        line.set_linestyle('None')
120                        self.lines[number] = None
121        self.show()       
122
123    def get_line(self):
124        """
125        Get the current default line attributes.
126        """
127        return self.attributes
128
129
130    def hist(self, x=None, y=None, fmt=None):
131        """
132        Plot a histogram.  N.B. the x values refer to the start of the
133        histogram bin.
134
135        fmt is the line style as in plot().
136        """
137
138        if x is None:
139            if y is None: return
140            x = range(0,len(y))
141
142        if len(x) != len(y):
143            return
144
145        l2 = 2*len(x)
146        x2 = range(0,l2)
147        y2 = range(0,l2)
148
149        for i in range(0,l2):
150            x2[i] = x[i/2]
151
152        y2[0] = 0
153        for i in range(1,l2):
154            y2[i] = y[(i-1)/2]
155
156        self.plot(x2, y2, fmt)
157
158
159    def hold(self, hold=True):
160        """
161        Buffer graphics until subsequently released.
162        """
163        self.buffering = hold
164
165
166    def legend(self, loc=None):
167        """
168        Add a legend to the plot.
169
170        Any other value for loc else disables the legend:
171             1: upper right
172             2: upper left
173             3: lower left
174             4: lower right
175             5: right
176             6: center left
177             7: center right
178             8: lower center
179             9: upper center
180            10: center
181
182        """
183        if isinstance(loc,int):
184            if 0 > loc > 10: loc = 0
185            self.loc = loc
186        self.show()
187
188
189    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
190        """
191        Plot the next line in the current frame using the current line
192        attributes.  The ASAPlot graphics window will be mapped and raised.
193
194        The argument list works a bit like the matlab plot() function.
195        """
196
197        if x is None:
198            if y is None: return
199            x = range(len(y))
200
201        elif y is None:
202            y = x
203            x = range(len(y))
204
205        if mask is None:
206            if fmt is None:
207                line = self.axes.plot(x, y)
208            else:
209                line = self.axes.plot(x, y, fmt)
210        else:
211            segments = []
212
213            mask = list(mask)
214            i = 0
215            while mask[i:].count(1):
216                i += mask[i:].index(1)
217                if mask[i:].count(0):
218                    j = i + mask[i:].index(0)
219                else:
220                    j = len(mask)
221
222                segments.append(x[i:j])
223                segments.append(y[i:j])
224
225                i = j
226
227            line = self.axes.plot(*segments)
228
229        # Add to an existing line?
230        if add is None or len(self.lines) < add < 0:
231            # Don't add.
232            self.lines.append(line)
233            i = len(self.lines) - 1
234        else:
235            if add == 0: add = len(self.lines)
236            i = add - 1
237            self.lines[i].extend(line)
238
239        # Set/reset attributes for the line.
240        gotcolour = False
241        for k, v in self.attributes.iteritems():
242            if k == 'color': gotcolour = True
243            for segment in self.lines[i]:
244                getattr(segment, "set_%s"%k)(v)
245
246        if not gotcolour and len(self.colormap):
247            for segment in self.lines[i]:
248                getattr(segment, "set_color")(self.colormap[self.color])
249                if len(self.colormap)  == 1:
250                    getattr(segment, "set_dashes")(self.linestyles[self.linestyle])
251            self.color += 1
252            if self.color >= len(self.colormap):
253                self.color = 0
254
255            if len(self.colormap) == 1:
256                self.linestyle += 1
257            if self.linestyle >= len(self.linestyles):
258                self.linestyle = 0               
259
260        self.show()
261
262
263    def position(self):
264        """
265        Use the mouse to get a position from a graph.
266        """
267
268        def position_disable(event):
269            self.register('button_press', None)
270            print '%.4f, %.4f' % (event.xdata, event.ydata)
271
272        print 'Press any mouse button...'
273        self.register('button_press', position_disable)
274
275
276    def region(self):
277        """
278        Use the mouse to get a rectangular region from a plot.
279
280        The return value is [x0, y0, x1, y1] in world coordinates.
281        """
282
283        def region_start(event):
284            height = self.canvas.figure.bbox.height()
285            self.rect = {'fig': None, 'height': height,
286                         'x': event.x, 'y': height - event.y,
287                         'world': [event.xdata, event.ydata,
288                                   event.xdata, event.ydata]}
289            self.register('button_press', None)
290            self.register('motion_notify', region_draw)
291            self.register('button_release', region_disable)
292
293        def region_draw(event):
294            self.canvas._tkcanvas.delete(self.rect['fig'])
295            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
296                                self.rect['x'], self.rect['y'],
297                                event.x, self.rect['height'] - event.y)
298
299        def region_disable(event):
300            self.register('motion_notify', None)
301            self.register('button_release', None)
302
303            self.canvas._tkcanvas.delete(self.rect['fig'])
304
305            self.rect['world'][2:4] = [event.xdata, event.ydata]
306            print '(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
307                self.rect['world'][1], self.rect['world'][2],
308                self.rect['world'][3])
309
310        self.register('button_press', region_start)
311
312        # This has to be modified to block and return the result (currently
313        # printed by region_disable) when that becomes possible in matplotlib.
314
315        return [0.0, 0.0, 0.0, 0.0]
316
317
318    def register(self, type=None, func=None):
319        """
320        Register, reregister, or deregister events of type 'button_press',
321        'button_release', or 'motion_notify'.
322
323        The specified callback function should have the following signature:
324
325            def func(event)
326
327        where event is an MplEvent instance containing the following data:
328
329            name                # Event name.
330            canvas              # FigureCanvas instance generating the event.
331            x      = None       # x position - pixels from left of canvas.
332            y      = None       # y position - pixels from bottom of canvas.
333            button = None       # Button pressed: None, 1, 2, 3.
334            key    = None       # Key pressed: None, chr(range(255)), shift,
335                                  win, or control
336            inaxes = None       # Axes instance if cursor within axes.
337            xdata  = None       # x world coordinate.
338            ydata  = None       # y world coordinate.
339
340        For example:
341
342            def mouse_move(event):
343                print event.xdata, event.ydata
344
345            a = asaplot()
346            a.register('motion_notify', mouse_move)
347
348        If func is None, the event is deregistered.
349
350        Note that in TkAgg keyboard button presses don't generate an event.
351        """
352
353        if not self.events.has_key(type): return
354
355        if func is None:
356            if self.events[type] is not None:
357                # It's not clear that this does anything.
358                self.canvas.mpl_disconnect(self.events[type])
359                self.events[type] = None
360
361                # It seems to be necessary to return events to the toolbar.
362                if type == 'motion_notify':
363                    self.canvas.mpl_connect(type + '_event',
364                        self.figmgr.toolbar.mouse_move)
365                elif type == 'button_press':
366                    self.canvas.mpl_connect(type + '_event',
367                        self.figmgr.toolbar.press)
368                elif type == 'button_release':
369                    self.canvas.mpl_connect(type + '_event',
370                        self.figmgr.toolbar.release)
371
372        else:
373            self.events[type] = self.canvas.mpl_connect(type + '_event', func)
374
375
376    def release(self):
377        """
378        Release buffered graphics.
379        """
380        self.buffering = False
381        self.show()
382
383
384    def save(self, fname=None, orientation=None, dpi=None):
385        """
386        Save the plot to a file.
387
388        fname is the name of the output file.  The image format is determined
389        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
390        file name is specified 'yyyymmdd_hhmmss.png' is created in the current
391        directory.
392        """
393        if fname is None:
394            from datetime import datetime
395            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
396            fname = 'asap'+dstr+'.png'
397
398        d = ['png','.ps','eps']
399
400        from os.path import expandvars
401        fname = expandvars(fname)
402
403        if fname[-3:].lower() in d:
404            try:
405                if fname[-3:].lower() == ".ps":
406                    w = self.figure.figwidth.get()
407                    h = self.figure.figheight.get()                       
408                    a4w = 8.25
409                    a4h = 11.25
410                   
411                    if orientation is None:
412                        # auto oriented
413                        if w > h:
414                            orientation = 'landscape'
415                        else:
416                            orientation = 'portrait'
417                    ds = None
418                    if orientation == 'landscape':
419                        ds = min(a4h/w,a4w/h)
420                    else:
421                        ds = min(a4w/w,a4h/h)
422                    ow = ds * w
423                    oh = ds * h
424                    self.figure.set_figsize_inches((ow,oh))
425                    self.canvas.print_figure(fname,orientation=orientation)
426                    print 'Written file %s' % (fname)
427                else:                   
428                    if dpi is None:
429                        dpi =150
430                    self.canvas.print_figure(fname,dpi=dpi)
431                    print 'Written file %s' % (fname)
432            except IOError, msg:
433                print 'Failed to save %s: Error msg was\n\n%s' % (fname, err)
434                return
435        else:
436            print "Invalid image type. Valid types are:"
437            print "'ps', 'eps', 'png'"
438
439
440    def set_axes(self, what=None, *args, **kwargs):
441        """
442        Set attributes for the axes by calling the relevant Axes.set_*()
443        method.  Colour translation is done as described in the doctext
444        for palette().
445        """
446
447        if what is None: return
448        if what[-6:] == 'colour': what = what[:-6] + 'color'
449
450        newargs = {}
451       
452        for k, v in kwargs.iteritems():
453            k = k.lower()
454            if k == 'colour': k = 'color'
455            newargs[k] = v
456
457        getattr(self.axes, "set_%s"%what)(*args, **newargs)
458        self.show()
459
460
461    def set_figure(self, what=None, *args, **kwargs):
462        """
463        Set attributes for the figure by calling the relevant Figure.set_*()
464        method.  Colour translation is done as described in the doctext
465        for palette().
466        """
467
468        if what is None: return
469        if what[-6:] == 'colour': what = what[:-6] + 'color'
470        #if what[-5:] == 'color' and len(args):
471        #    args = (get_colour(args[0]),)
472
473        newargs = {}
474        for k, v in kwargs.iteritems():
475            k = k.lower()
476            if k == 'colour': k = 'color'
477            newargs[k] = v
478
479        getattr(self.figure, "set_%s"%what)(*args, **newargs)
480        self.show()
481
482
483    def set_limits(self, xlim=None, ylim=None):
484        """
485        Set x-, and y-limits for each subplot.
486
487        xlim = [xmin, xmax] as in axes.set_xlim().
488        ylim = [ymin, ymax] as in axes.set_ylim().
489        """
490        for s in self.subplots:
491            self.axes  = s['axes']
492            self.lines = s['lines']
493            oldxlim =  list(self.axes.get_xlim())
494            oldylim =  list(self.axes.get_ylim())
495            if xlim is not None:
496                for i in range(len(xlim)):
497                    if xlim[i] is not None:
498                        oldxlim[i] = xlim[i]
499            if ylim is not None:                       
500                for i in range(len(ylim)):
501                    if ylim[i] is not None:
502                        oldylim[i] = ylim[i]
503            self.axes.set_xlim(oldxlim)
504            self.axes.set_ylim(oldylim)
505        return
506
507
508    def set_line(self, number=None, **kwargs):
509        """
510        Set attributes for the specified line, or else the next line(s)
511        to be plotted.
512
513        number is the 0-relative number of a line that has already been
514        plotted.  If no such line exists, attributes are recorded and used
515        for the next line(s) to be plotted.
516
517        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do
518
519            import matplotlib
520            help(matplotlib.lines)
521
522        The set_* methods of class Line2D define the attribute names and
523        values.  For non-US usage, "colour" is recognized as synonymous with
524        "color".
525
526        Set the value to None to delete an attribute.
527
528        Colour translation is done as described in the doctext for palette().
529        """
530
531        redraw = False
532        for k, v in kwargs.iteritems():
533            k = k.lower()
534            if k == 'colour': k = 'color'
535
536            if 0 <= number < len(self.lines):
537                if self.lines[number] is not None:
538                    for line in self.lines[number]:
539                        getattr(line, "set_%s"%k)(v)
540                    redraw = True
541            else:
542                if v is None:
543                    del self.attributes[k]
544                else:
545                    self.attributes[k] = v
546
547        if redraw: self.show()
548
549
550    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
551        """
552        Set the panel layout.
553
554        rows and cols, if cols != 0, specify the number of rows and columns in
555        a regular layout.   (Indexing of these panels in matplotlib is row-
556        major, i.e. column varies fastest.)
557
558        cols == 0 is interpreted as a retangular layout that accomodates
559        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
560        rows == 2, cols == 3.
561
562        0 <= n < rows*cols is interpreted as the 0-relative panel number in
563        the configuration specified by rows and cols to be added to the
564        current figure as its next 0-relative panel number (i).  This allows
565        non-regular panel layouts to be constructed via multiple calls.  Any
566        other value of n clears the plot and produces a rectangular array of
567        empty panels.  The number of these may be limited by nplots.
568        """
569        if n < 0 and len(self.subplots):
570            self.figure.clear()
571            self.set_title()
572
573        if rows < 1: rows = 1
574
575        if cols <= 0:
576            i = int(sqrt(rows))
577            if i*i < rows: i += 1
578            cols = i
579
580            if i*(i-1) >= rows: i -= 1
581            rows = i
582
583        if 0 <= n < rows*cols:
584            i = len(self.subplots)
585            self.subplots.append({})
586
587            self.subplots[i]['axes']  = self.figure.add_subplot(rows,
588                                            cols, n+1)
589            self.subplots[i]['lines'] = []
590
591            if i == 0: self.subplot(0)
592
593            self.rows = 0
594            self.cols = 0
595
596        else:
597            self.subplots = []
598
599            if nplots < 1 or rows*cols < nplots:
600                nplots = rows*cols
601
602            for i in range(nplots):
603                self.subplots.append({})
604
605                self.subplots[i]['axes']  = self.figure.add_subplot(rows,
606                                                cols, i+1)
607                self.subplots[i]['lines'] = []
608                xfsize = self.subplots[i]['axes'].xaxis.label.get_size()-cols/2
609                yfsize = self.subplots[i]['axes'].yaxis.label.get_size()-rows/2
610                self.subplots[i]['axes'].xaxis.label.set_size(xfsize)
611                self.subplots[i]['axes'].yaxis.label.set_size(yfsize)
612               
613                if ganged:
614                    if rows > 1 or cols > 1:
615                        # Squeeze the plots together.
616                        pos = self.subplots[i]['axes'].get_position()
617                        if cols > 1: pos[2] *= 1.2
618                        if rows > 1: pos[3] *= 1.2
619                        self.subplots[i]['axes'].set_position(pos)
620
621                    # Suppress tick labelling for interior subplots.
622                    if i <= (rows-1)*cols - 1:
623                        if i+cols < nplots:
624                            # Suppress x-labels for frames width
625                            # adjacent frames
626                            for tick in \
627                                    self.subplots[i]['axes'].xaxis.majorTicks:
628                                tick.label1On = False
629                                self.subplots[i]['axes'].xaxis.label.set_visible(False)
630                    if i%cols:
631                        # Suppress y-labels for frames not in the left column.
632                        for tick in self.subplots[i]['axes'].yaxis.majorTicks:
633                            tick.label1On = False
634                        self.subplots[i]['axes'].yaxis.label.set_visible(False)
635                       
636
637                self.rows = rows
638                self.cols = cols
639
640            self.subplot(0)
641
642    def set_title(self, title=None):
643        """
644        Set the title of the plot window.  Use the previous title if title is
645        omitted.
646        """
647        if title is not None:
648            self.title = title
649
650        self.figure.text(0.5, 0.95, self.title, horizontalalignment='center')
651
652
653    def show(self):
654        """
655        Show graphics dependent on the current buffering state.
656        """
657        if not self.buffering:
658            if self.loc is not None:
659                for j in range(len(self.subplots)):
660                    lines  = []
661                    labels = []
662                    i = 0
663                    for line in self.subplots[j]['lines']:
664                        i += 1
665                        if line is not None:
666                            lines.append(line[0])
667                            lbl = line[0].get_label()
668                            if lbl == '':
669                                lbl = str(i)
670                            labels.append(lbl)
671
672                    if len(lines):
673                        self.subplots[j]['axes'].legend(tuple(lines),
674                                                        tuple(labels),
675                                                        self.loc)
676                    else:
677                        self.subplots[j]['axes'].legend((' '))
678
679
680    def subplot(self, i=None, inc=None):
681        """
682        Set the subplot to the 0-relative panel number as defined by one or
683        more invokations of set_panels().
684        """
685        l = len(self.subplots)
686        if l:
687            if i is not None:
688                self.i = i
689
690            if inc is not None:
691                self.i += inc
692
693            self.i %= l
694            self.axes  = self.subplots[self.i]['axes']
695            self.lines = self.subplots[self.i]['lines']
696
697
698    def text(self, *args, **kwargs):
699        """
700        Add text to the figure.
701        """
702        self.figure.text(*args, **kwargs)
703        self.show()
Note: See TracBrowser for help on using the repository browser.