source: trunk/python/asaplotbase.py @ 1019

Last change on this file since 1019 was 1019, checked in by mar637, 18 years ago

Some work on the multipanel cosmetics: use font scaling on labels, and drop the last x-axis tick label if there is an adjacent panel.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.0 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15
16from matplotlib.ticker import ScalarFormatter
17from matplotlib.ticker import NullLocator
18
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter that leaves one tick label near the right-hand end of
    the axis blank.

    asaplotbase.set_panels() installs it on the x axis of ganged panels
    that are not in the right-most column, so that the end label does not
    run into the first label of the panel to the right.
    """

    def __call__(self, x, pos=None):
        """
        Return the label for the tick at value x with index pos, or '' for
        the suppressed tick.
        """
        # The blanked tick is the one at index len(self.locs) - 2 (self.locs
        # holds the tick locations known to the formatter).
        # NOTE(review): -2 rather than -1 was presumably tuned against the
        # matplotlib version in use; confirm before changing it.
        if pos == len(self.locs) - 2:
            return ''
        return ScalarFormatter.__call__(self, x, pos)
26
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    The figure is divided into one or more panels.  Each panel is recorded
    in self.subplots as a dict with keys 'axes' (the matplotlib Axes) and
    'lines' (a list whose entries are the lists of Line2D objects returned
    by Axes.plot(), or None once deleted).  self.axes and self.lines always
    refer to the current panel, as selected by subplot().

    NOTE(review): register() and region() use self.events, self.figmgr and
    self.canvas._tkcanvas, none of which are created in this class; they are
    presumably supplied by a backend-specific subclass.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        Parameters:
            rows, cols:  initial panel layout, passed to set_panels().
            title:       figure title, drawn by set_title().
            size:        figure size (width, height) passed to Figure.
            buffering:   initial buffering state; see hold() / release().
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        # Panel geometry used by set_axes() for font scaling.  set_panels()
        # overwrites these; the defaults only matter when rows < 1 defers
        # the layout.
        self.rows = 0
        self.cols = 0

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence, optionally overridden by the
        # whitespace-separated 'plotter.colours' rc parameter.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named dash patterns; plot() hands these to Line2D.set_dashes().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        # Ignore a whitespace-only setting; an empty style list would make
        # plot() index out of range.
        if isinstance(c, str) and len(c.split()) > 0:
            styles = c.split()
        # Unknown names fall back to a solid line, as in palette(); every
        # entry must be a dash sequence because plot() passes it to
        # set_dashes().
        self.linestyles = [self.lsalias.get(ls, self.lsalias['line'])
                           for ls in styles]

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def clear(self):
        """
        Delete all lines from the current panel.  Line numbering will
        restart from 1.

        The line list is emptied in place so that the panel record in
        self.subplots, which shares the same list object, stays in step
        with self.lines.
        """
        # Axes.clear() removes every artist from the panel, so the lines do
        # not need to be hidden one by one first.
        self.axes.clear()
        self.color = 0
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Set the colour and line-style cycles used by plot().

        Parameters:
            color:       0-relative index of the next colour to use; ignored
                         if out of range for the (possibly new) colour map.
            colormap:    list/tuple of colour names, or a whitespace
                         separated string of them.
            linestyle:   0-relative index of the next line style; ignored if
                         out of range.
            linestyles:  list/tuple or whitespace separated string of names
                         from self.lsalias; unknown names map to a solid
                         line.  An empty result leaves the styles unchanged.
        """
        if colormap:
            if isinstance(colormap, str):
                self.colormap = colormap.split()
            elif isinstance(colormap, (list, tuple)):
                self.colormap = list(colormap)
        if 0 <= color < len(self.colormap):
            self.color = color
        if linestyles:
            names = []
            if isinstance(linestyles, str):
                names = linestyles.split()
            elif isinstance(linestyles, (list, tuple)):
                names = list(linestyles)
            if names:
                solid = self.lsalias.get('line')
                self.linestyles = [self.lsalias.get(ls, solid) for ls in names]
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers may be a single index or an iterable of indices; indices
        that are out of range or already deleted are ignored.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                # Hide the segments; the list slot is kept (as None) so that
                # later line numbers do not shift.
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes.
        """
        return self.attributes

    def hist(self, x=None, y=None, fmt=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        If only one of x and y is given it is used as y and plotted against
        its index, as in plot().  Mismatched lengths or empty data plot
        nothing.

        fmt is the line style as in plot().
        """
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return

        npts = 2*len(x)
        # Step outline: each x value appears twice, and the y values are
        # shifted by one point so the curve rises vertically at the start
        # of every bin (starting from zero).
        xstep = [x[k // 2] for k in range(npts)]
        ystep = [0] + [y[(k - 1) // 2] for k in range(1, npts)]

        # fmt must be passed by keyword: plot()'s third positional
        # parameter is mask.
        self.plot(xstep, ystep, fmt=fmt)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        loc is a matplotlib legend location code; any other integer
        disables the legend, and a non-integer leaves the setting unchanged:
             0: best (automatic placement)
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        """
        if isinstance(loc, int):
            if 0 <= loc <= 10:
                self.loc = loc
            else:
                self.loc = None
        self.show()

    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes, then call show().

        The argument list works a bit like the matlab plot() function.

        Parameters:
            x, y:  data.  If only one is given it is used as y and plotted
                   against its index.
            mask:  optional sequence parallel to x/y.  A segment starts at
                   each element equal to 1 and runs up to (not including)
                   the next element equal to 0; each segment is drawn
                   separately.  fmt is not used when a mask is given.
            fmt:   matplotlib format string passed to Axes.plot().
            add:   1-relative number of an existing line to extend with the
                   new segments, or 0 for the most recent line.  None, an
                   out-of-range number or a deleted line starts a new line.
        """
        if x is None:
            if y is None: return
            x = range(len(y))

        elif y is None:
            y = x
            x = range(len(y))

        if mask is None:
            if fmt is None:
                line = self.axes.plot(x, y)
            else:
                line = self.axes.plot(x, y, fmt)
        else:
            # Single linear pass over the mask, collecting x/y slices.
            flags = list(mask)
            nflags = len(flags)
            segments = []
            i = 0
            while i < nflags:
                if flags[i] == 1:
                    j = i + 1
                    while j < nflags and flags[j] != 0:
                        j += 1
                    segments.append(x[i:j])
                    segments.append(y[i:j])
                    i = j
                else:
                    i += 1
            line = self.axes.plot(*segments)

        # Add to an existing line?
        if add == 0:
            add = len(self.lines)
        if add is None or not 1 <= add <= len(self.lines) \
           or self.lines[add - 1] is None:
            # Don't add.
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            i = add - 1
            self.lines[i].extend(line)

        # Explicit attributes from set_line() apply to every segment.
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color':
                gotcolour = True
            for segment in self.lines[i]:
                getattr(segment, "set_%s" % k)(v)

        # Otherwise cycle through the colour map; with a single colour,
        # cycle the dash patterns instead so lines stay distinguishable.
        ncolours = len(self.colormap)
        if not gotcolour and ncolours:
            # The modulo guards against a colour index left out of range by
            # a shorter colour map installed via palette().
            cidx = self.color % ncolours
            single = (ncolours == 1)
            for segment in self.lines[i]:
                segment.set_color(self.colormap[cidx])
                if single:
                    segment.set_dashes(
                        self.linestyles[self.linestyle % len(self.linestyles)])
            self.color = (cidx + 1) % ncolours

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        The next button press prints its world coordinates and then
        deregisters itself.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The return value is [x0, y0, x1, y1] in world coordinates.
        NOTE: the value is currently a placeholder of zeros; the selected
        region is printed when the mouse button is released.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            # Redraw the rubber-band rectangle on the Tk canvas.
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events: return

        if func is None:
            if self.events[type] is not None:
                # It's not clear that this does anything.
                self.canvas.mpl_disconnect(self.events[type])
                self.events[type] = None

                # It seems to be necessary to return events to the toolbar.
                if type == 'motion_notify':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.mouse_move)
                elif type == 'button_press':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.press)
                elif type == 'button_release':
                    self.canvas.mpl_connect(type + '_event',
                        self.figmgr.toolbar.release)

        else:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file; environment variables in it
        are expanded.  The image format is determined from the file
        extension; '.png', '.ps', and '.eps' are recognized.  If no file name
        is specified 'asapyyyymmdd_hhmmss.png' is created in the current
        directory.

        orientation ('portrait' or 'landscape') applies to '.ps' output and
        defaults to the figure's aspect ratio.  dpi applies to the other
        formats and defaults to 150.  I/O errors are reported, not raised.
        """
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        from os.path import expandvars, splitext
        fname = expandvars(fname)

        ext = splitext(fname)[1].lower()
        if ext not in ('.png', '.ps', '.eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if ext == '.ps':
                w = self.figure.figwidth.get()
                h = self.figure.figheight.get()
                # Page size (inches) the figure is scaled to fit; close to
                # A4 width, though shorter than A4's 11.69 in height.
                a4w = 8.25
                a4h = 11.25

                if orientation is None:
                    # auto oriented
                    if w > h:
                        orientation = 'landscape'
                    else:
                        orientation = 'portrait'
                if orientation == 'landscape':
                    ds = min(a4h/w, a4w/h)
                else:
                    ds = min(a4w/w, a4h/h)
                self.figure.set_figsize_inches((ds * w, ds * h))
                try:
                    self.canvas.print_figure(fname, orientation=orientation)
                finally:
                    # Restore the on-screen size after the page-sized export.
                    self.figure.set_figsize_inches((w, h))
            else:
                if dpi is None:
                    dpi = 150
                self.canvas.print_figure(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            # sys.exc_info() keeps this valid on both old and new Pythons.
            err = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, err))
            return

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the current axes by calling the relevant
        Axes.set_*() method, e.g. set_axes('xlabel', 'Channel').

        'colour' is accepted as a synonym for 'color', both at the end of
        what and as a keyword name; keyword names are lower-cased.

        When what is 'title', 'xlabel' or 'ylabel', that text's font is
        shrunk according to the panel layout.  Only the text just set is
        scaled, so repeated calls do not keep shrinking other text.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v

        getattr(self.axes, "set_%s" % what)(*args, **newargs)

        if what == 'title':
            title = self.axes.title
            title.set_size(title.get_size() - (self.cols + self.rows)//2 - 1)
        elif what == 'xlabel' and self.cols > 1:
            label = self.axes.xaxis.label
            label.set_size(label.get_size() - (self.cols + 1)//2)
        elif what == 'ylabel' and self.rows > 1:
            label = self.axes.yaxis.label
            label.set_size(label.get_size() - (self.rows + 1)//2)

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  'colour' is accepted as a synonym for 'color', both at the
        end of what and as a keyword name; keyword names are lower-cased.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v

        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        A None entry keeps that panel's current limit.  The current panel
        (self.axes / self.lines) is not changed.
        """
        for sp in self.subplots:
            axes = sp['axes']
            newx = list(axes.get_xlim())
            newy = list(axes.get_ylim())
            if xlim is not None:
                for k, val in enumerate(xlim):
                    if val is not None:
                        newx[k] = val
            if ylim is not None:
                for k, val in enumerate(ylim):
                    if val is not None:
                        newy[k] = val
            axes.set_xlim(newx)
            axes.set_ylim(newy)
        return

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is harmless).
        """

        redraw = False
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'

            if number is not None and 0 <= number < len(self.lines):
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            else:
                if v is None:
                    self.attributes.pop(k, None)
                else:
                    self.attributes[k] = v

        if redraw: self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True the panels of a regular layout are squeezed
        together and tick labels between adjacent panels are suppressed.
        """
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if rows < 1: rows = 1

        if cols <= 0:
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({'axes': self.figure.add_subplot(rows,
                                                                  cols, n+1),
                                  'lines': []})

            if i == 0: self.subplot(0)

            # Irregular layout: no font scaling in set_axes().
            self.rows = 0
            self.cols = 0

        else:
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols

            for i in range(nplots):
                axes = self.figure.add_subplot(rows, cols, i+1)
                self.subplots.append({'axes': axes, 'lines': []})

                if ganged:
                    if rows > 1 or cols > 1:
                        # Squeeze the plots together.
                        pos = axes.get_position()
                        if cols > 1: pos[2] *= 1.2
                        if rows > 1: pos[3] *= 1.2
                        axes.set_position(pos)

                    # A panel sits directly below: suppress x tick labels
                    # and the x-axis label.
                    if i + cols < nplots:
                        axes.xaxis.set_major_locator(NullLocator())
                        axes.xaxis.label.set_visible(False)
                    if i % cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in axes.yaxis.majorTicks:
                            tick.label1On = False
                        axes.yaxis.label.set_visible(False)
                    # A panel sits directly to the right: blank the end tick
                    # label so it does not collide with that panel's first.
                    if (i+1) % cols and i + 1 < nplots:
                        axes.xaxis.set_major_formatter(MyFormatter())

            self.rows = rows
            self.cols = cols

            self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The existing title text is updated in place if it is still part of
        the figure; otherwise (first call, or after figure.clear()) a new
        text is created, so titles never pile up on top of each other.
        """
        if title is not None:
            self.title = title

        current = getattr(self, '_titletext', None)
        if current is not None and current in self.figure.texts:
            current.set_text(self.title)
        else:
            self._titletext = self.figure.text(0.5, 0.95, self.title,
                                               horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        When not buffering, rebuild the legend of every panel from its
        lines (labelled by the line's own label, or its 1-relative number),
        or remove the legends if legend() disabled them (self.loc is None).
        """
        if self.buffering:
            return

        if self.loc is None:
            # NOTE(review): relies on the Axes.legend_ attribute that
            # Axes.legend() sets; confirm against the matplotlib in use.
            for sp in self.subplots:
                sp['axes'].legend_ = None
            return

        for sp in self.subplots:
            handles = []
            labels = []
            for idx, line in enumerate(sp['lines']):
                # Skip deleted (None) and empty entries, but keep numbering.
                if line:
                    handles.append(line[0])
                    lbl = line[0].get_label()
                    if lbl == '':
                        lbl = str(idx + 1)
                    labels.append(lbl)

            if len(handles):
                sp['axes'].legend(tuple(handles), tuple(labels), self.loc)
            else:
                # Same call as before: a single blank label string.
                sp['axes'].legend(' ')

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        inc, if given, is added to the panel number; the result wraps
        around the number of panels.  Does nothing if there are no panels.
        """
        l = len(self.subplots)
        if l:
            if i is not None:
                self.i = i

            if inc is not None:
                self.i += inc

            self.i %= l
            self.axes  = self.subplots[self.i]['axes']
            self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure (arguments as for Figure.text()).
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.