source: trunk/python/asaplotbase.py @ 1020

Last change on this file since 1020 was 1020, checked in by mar637, 18 years ago

Made the matplotlib hack in save() version-dependent, as the underlying problem has been fixed in versions >0.85.*

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.2 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15
16from matplotlib.ticker import ScalarFormatter
17from matplotlib.ticker import NullLocator
18
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter that leaves one tick label blank.

    set_panels() installs it on the x axis of panels that are not in the
    rightmost column, presumably so that the end label of one panel does
    not collide with the first label of its right-hand neighbour.
    """
    def __call__(self, x, pos=None):
        # The suppressed label is the one at index len(self.locs) - 2, i.e.
        # the second-to-last tick location.  NOTE(review): presumably the
        # final entry of self.locs lies outside the visible range -- confirm
        # against the matplotlib version in use.
        if pos == len(self.locs) - 2:
            return ''
        return ScalarFormatter.__call__(self, x, pos)

class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    Manages a matplotlib Figure divided into panels (see set_panels()),
    records the lines plotted in each panel and cycles through colours and
    dash patterns for successive lines.  self.canvas is initialised to None
    here, and register() relies on self.events and self.figmgr, which this
    class does not define -- presumably a backend-specific subclass supplies
    them.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        The default colour and line-style sequences may be overridden by the
        asap rc parameters 'plotter.colours' and 'plotter.linestyles'
        (whitespace-separated names) when those are non-empty strings.
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence, optionally replaced from the rc file.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named on/off dash patterns, applied with Line2D.set_dashes().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        # Unknown names fall back to a solid line, as in palette();
        # set_dashes() needs an on/off sequence, not a format string.
        solid = self.lsalias['line']
        self.linestyles = [self.lsalias.get(ls, solid) for ls in styles]

        self.color = 0         # index of the next colour in self.colormap
        self.linestyle = 0     # index of the next pattern in self.linestyles
        self.attributes = {}   # Line2D attributes applied to new lines
        self.loc = 0           # legend location code, see legend()

        self.buffering = buffering

    def _normalise_kwargs(self, kwargs):
        """
        Return a copy of kwargs with lower-cased keys and the spelling
        'colour' translated to 'color'.
        """
        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour':
                k = 'color'
            newargs[k] = v
        return newargs

    def clear(self):
        """
        Delete all lines from the plot.  Line numbering will restart from 1.
        """
        if self.lines:
            # One delete() call, hence a single redraw.
            self.delete(range(len(self.lines)))
        self.axes.clear()
        self.color = 0
        # Empty the list in place: self.lines is the same object as the
        # current subplot's 'lines' entry, which show() reads for legends.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Select the colour/dash-pattern sequences and their current indices.

        color      -- index into the colour map used for the next line;
                      ignored if out of range.
        colormap   -- whitespace-separated string or sequence of colour
                      names replacing the current colour map.
        linestyle  -- index into the line-style list used for the next
                      line; ignored if out of range.
        linestyles -- whitespace-separated string or sequence of names from
                      self.lsalias; unknown names become a solid line.

        If a shortened map leaves a current index out of range, that index
        is reset to 0.
        """
        if colormap:
            if isinstance(colormap, str):
                self.colormap = colormap.split()
            elif isinstance(colormap, (list, tuple)):
                self.colormap = list(colormap)
        if 0 <= color < len(self.colormap):
            self.color = color
        elif self.color >= len(self.colormap):
            self.color = 0
        if linestyles:
            if isinstance(linestyles, str):
                styles = linestyles.split()
            else:
                styles = list(linestyles)
            solid = self.lsalias['line']
            self.linestyles = [self.lsalias.get(ls, solid) for ls in styles]
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle
        elif self.linestyle >= len(self.linestyles):
            self.linestyle = 0

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number, default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers may be a single index or an iterable of indices; indices
        that are out of range or already deleted are ignored.  Deleted lines
        are hidden via set_linestyle('None') and their entry becomes None.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes.
        """
        return self.attributes

    def hist(self, x=None, y=None, fmt=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        fmt is the line style as in plot().  As in plot(), a single
        positional argument is taken as y with x = 0, 1, 2, ...  Nothing is
        drawn if x and y differ in length or are empty.

        NOTE(review): the traced outline ends at (x[-1], y[-1]), so the
        last bin has no horizontal top.
        """
        if x is None:
            if y is None:
                return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return

        # Step outline: every x value is repeated, and y is delayed by one
        # point after an initial zero, giving vertical then horizontal moves.
        l2 = 2 * len(x)
        x2 = [x[i // 2] for i in range(l2)]
        y2 = [0] + [y[(i - 1) // 2] for i in range(1, l2)]

        # fmt must be passed by keyword: plot()'s third parameter is mask.
        self.plot(x2, y2, fmt=fmt)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        loc is a matplotlib legend location code:
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center

        Integers outside 0..10 are replaced by 0, which is passed to
        matplotlib unchanged (in current matplotlib 0 means 'best').
        Non-integer values, including None, keep the current setting.  The
        plot is redrawn in either case.
        """
        if isinstance(loc, int):
            if not 0 <= loc <= 10:
                loc = 0
            self.loc = loc
        self.show()

    def _masked_segments(self, x, y, mask, fmt=None):
        """
        Build Axes.plot() arguments for the runs of x/y selected by mask.

        A run starts at a mask element equal to 1 and extends up to, but not
        including, the next element equal to 0 (or the end of the mask).
        Each run contributes x[i:j], y[i:j] and, if given, fmt.
        """
        mask = list(mask)
        n = len(mask)
        args = []
        i = 0
        while i < n:
            if mask[i] != 1:
                i += 1
                continue
            j = i + 1
            while j < n and mask[j] != 0:
                j += 1
            args.append(x[i:j])
            args.append(y[i:j])
            if fmt is not None:
                args.append(fmt)
            i = j
        return args

    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes.  The ASAPlot graphics window will be mapped and raised.

        The argument list works a bit like the matlab plot() function.

        x, y -- data; with only one of them given it is taken as y and
                x = 0, 1, 2, ...
        mask -- optional sequence; only the runs starting at a 1 and ending
                before the next 0 are drawn (see _masked_segments()).
        fmt  -- optional matplotlib format string, also applied to every
                masked run.
        add  -- 1-relative number of an existing line to extend instead of
                starting a new one; 0 means the most recent line.  A value
                that is out of range or names a deleted line starts a new
                line.
        """
        if x is None:
            if y is None:
                return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if mask is None:
            if fmt is None:
                line = self.axes.plot(x, y)
            else:
                line = self.axes.plot(x, y, fmt)
        else:
            line = self.axes.plot(*self._masked_segments(x, y, mask, fmt))

        # Add to an existing line?
        if add == 0:
            add = len(self.lines)
        if (add is None or not 1 <= add <= len(self.lines)
                or self.lines[add - 1] is None):
            self.lines.append(line)
            i = len(self.lines) - 1
        else:
            i = add - 1
            self.lines[i].extend(line)
        segments = self.lines[i]

        # Set/reset attributes for the line.
        gotcolour = False
        for k, v in self.attributes.items():
            if k == 'color':
                gotcolour = True
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        # Without an explicit colour, cycle through the colour map; with a
        # single-colour map, cycle through the dash patterns instead.
        if not gotcolour and len(self.colormap):
            single = len(self.colormap) == 1
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                if single:
                    segment.set_dashes(self.linestyles[self.linestyle])
            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if single:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        The world coordinates of the next button press are printed.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        The return value is [x0, y0, x1, y1] in world coordinates.
        Currently the region is only printed and a dummy list of zeros is
        returned (see the note at the end of this method).
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            # Redraw the rubber-band rectangle on the Tk canvas.
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered.  Event types not present
        in self.events are ignored.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events:
            return

        if func is not None:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)
            return

        if self.events[type] is None:
            return

        # It's not clear that this does anything.
        self.canvas.mpl_disconnect(self.events[type])
        self.events[type] = None

        # It seems to be necessary to return events to the toolbar.
        if type == 'motion_notify':
            self.canvas.mpl_connect(type + '_event',
                self.figmgr.toolbar.mouse_move)
        elif type == 'button_press':
            self.canvas.mpl_connect(type + '_event',
                self.figmgr.toolbar.press)
        elif type == 'button_release':
            self.canvas.mpl_connect(type + '_event',
                self.figmgr.toolbar.release)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file; environment variables in it
        are expanded.  The image format is determined from the file
        extension; '.png', '.ps' and '.eps' (any case) are recognized.  If
        no file name is specified 'asapyyyymmdd_hhmmss.png' is created in
        the current directory.

        orientation ('portrait' or 'landscape') applies to '.ps' output
        only and defaults to landscape for figures wider than they are
        tall.  dpi applies to the other formats and defaults to 150.

        Errors are reported by printing, not raised: an unknown extension
        prints the valid types, and an IOError while writing prints the
        error message.
        """
        from os.path import expandvars, splitext

        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap'+dstr+'.png'

        fname = expandvars(fname)
        ext = splitext(fname)[1].lower()

        if ext not in ('.png', '.ps', '.eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if ext == '.ps':
                from matplotlib import __version__ as mv
                w = self.figure.figwidth.get()
                h = self.figure.figheight.get()

                if orientation is None:
                    # auto oriented
                    if w > h:
                        orientation = 'landscape'
                    else:
                        orientation = 'portrait'
                # Hack to circumvent a PostScript bug in matplotlib < 0.86:
                # scale the figure to fit an a4w x a4h inch page.  Compare
                # (major, minor) so that e.g. 1.0 or 0.91rc1 parse correctly.
                # NOTE: this permanently resizes self.figure.
                vm = match(r'(\d+)\.(\d+)', mv)
                if vm and (int(vm.group(1)), int(vm.group(2))) < (0, 86):
                    a4w = 8.25
                    a4h = 11.25
                    if orientation == 'landscape':
                        ds = min(a4h/w, a4w/h)
                    else:
                        ds = min(a4w/w, a4h/h)
                    self.figure.set_figsize_inches((ds * w, ds * h))
                self.canvas.print_figure(fname, orientation=orientation)
            else:
                if dpi is None:
                    dpi = 150
                self.canvas.print_figure(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            # sys.exc_info() keeps this valid on every Python version.
            err = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, err))
            return

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the axes by calling the relevant Axes.set_*()
        method.  'colour' is accepted for 'color' both in what and in
        keyword names.

        Afterwards the title font (and, for multi-column/multi-row layouts,
        the x/y label fonts) of the current axes is shrunk by an amount
        that grows with the number of panels.  NOTE(review): the shrink is
        applied relative to the current size on every call.
        """

        if what is None:
            return
        if what[-6:] == 'colour':
            what = what[:-6] + 'color'

        newargs = self._normalise_kwargs(kwargs)
        getattr(self.axes, "set_%s" % what)(*args, **newargs)

        # Integer division keeps the original (Python 2) font arithmetic.
        s = self.axes.title.get_size()
        tsize = s - (self.cols + self.rows) // 2 - 1
        self.axes.title.set_size(tsize)
        if self.cols > 1:
            xfsize = self.axes.xaxis.label.get_size() - (self.cols + 1) // 2
            self.axes.xaxis.label.set_size(xfsize)
        if self.rows > 1:
            yfsize = self.axes.yaxis.label.get_size() - (self.rows + 1) // 2
            self.axes.yaxis.label.set_size(yfsize)

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  'colour' is accepted for 'color' both in what and in
        keyword names.
        """

        if what is None:
            return
        if what[-6:] == 'colour':
            what = what[:-6] + 'color'

        newargs = self._normalise_kwargs(kwargs)
        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        A None entry (or a None xlim/ylim) keeps the existing limit.  The
        current subplot selection (self.axes/self.lines) is not changed.
        """
        for sp in self.subplots:
            axes = sp['axes']
            newxlim = list(axes.get_xlim())
            newylim = list(axes.get_ylim())
            if xlim is not None:
                for i, v in enumerate(xlim):
                    if v is not None:
                        newxlim[i] = v
            if ylim is not None:
                for i, v in enumerate(ylim):
                    if v is not None:
                        newylim[i] = v
            axes.set_xlim(newxlim)
            axes.set_ylim(newylim)
        return

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is not an error).
        """

        existing = number is not None and 0 <= number < len(self.lines)
        redraw = False
        for k, v in self._normalise_kwargs(kwargs).items():
            if existing:
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw:
            self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True the panels are enlarged by 20% to squeeze them
        together, and x tick labels and axis labels are hidden on interior
        panels (x) and on panels outside the left column (y).
        """
        if rows < 1:
            rows = 1

        if cols <= 0:
            # Near-square grid: cols = ceil(sqrt(rows)), then as few rows
            # as still hold all the panels.
            i = int(sqrt(rows))
            if i * i < rows:
                i += 1
            cols = i

            if i * (i - 1) >= rows:
                i -= 1
            rows = i

        single_panel = 0 <= n < rows * cols

        if not single_panel and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if single_panel:
            self.subplots.append({'axes': self.figure.add_subplot(rows,
                                                                  cols, n+1),
                                  'lines': []})
            if len(self.subplots) == 1:
                self.subplot(0)

            self.rows = 0
            self.cols = 0
            return

        self.subplots = []

        if nplots < 1 or rows * cols < nplots:
            nplots = rows * cols

        for i in range(nplots):
            axes = self.figure.add_subplot(rows, cols, i+1)
            self.subplots.append({'axes': axes, 'lines': []})

            if ganged:
                if rows > 1 or cols > 1:
                    # Squeeze the plots together.
                    pos = axes.get_position()
                    if cols > 1: pos[2] *= 1.2
                    if rows > 1: pos[3] *= 1.2
                    axes.set_position(pos)

                # Suppress x-labels for frames with a frame directly below.
                if i <= (rows - 1) * cols - 1 and i + cols < nplots:
                    axes.xaxis.set_major_locator(NullLocator())
                    axes.xaxis.label.set_visible(False)
                if i % cols:
                    # Suppress y-labels for frames not in the left column.
                    for tick in axes.yaxis.majorTicks:
                        tick.label1On = False
                    axes.yaxis.label.set_visible(False)
                if (i + 1) % cols:
                    # Not in the rightmost column: blank one x tick label.
                    axes.xaxis.set_major_formatter(MyFormatter())

        self.rows = rows
        self.cols = cols
        self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The title text is reused while it is still in the figure, so
        repeated calls replace the title instead of overprinting it.
        """
        if title is not None:
            self.title = title

        # NOTE(review): figure.clear() is expected to drop the old text from
        # figure.texts, in which case a fresh text object is created.
        text = getattr(self, '_titletext', None)
        if text is not None and text in self.figure.texts:
            text.set_text(self.title)
        else:
            self._titletext = self.figure.text(0.5, 0.95, self.title,
                                               horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering, or self.loc is None, (re)draws a legend on every
        subplot.  Each live line is represented by its first segment,
        labelled with its own label or else its 1-relative line number.
        Deleted (None) and empty entries are skipped.
        """
        if self.buffering or self.loc is None:
            return
        for sp in self.subplots:
            handles = []
            labels = []
            for num, line in enumerate(sp['lines']):
                if not line:
                    continue
                handles.append(line[0])
                lbl = line[0].get_label()
                if lbl == '':
                    lbl = str(num + 1)
                labels.append(lbl)

            if handles:
                sp['axes'].legend(tuple(handles), tuple(labels), self.loc)
            else:
                sp['axes'].legend((' '))

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        i selects a panel directly; inc is then added to the selection.  The
        result wraps modulo the number of panels.  Does nothing if no
        panels exist.
        """
        npanels = len(self.subplots)
        if not npanels:
            return
        if i is not None:
            self.i = i
        if inc is not None:
            self.i += inc
        self.i %= npanels
        self.axes = self.subplots[self.i]['axes']
        self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure (arguments as for Figure.text()).
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.