source: trunk/python/asaplotbase.py @ 1023

Last change on this file since 1023 was 1023, checked in by mar637, 18 years ago

The previous histogram plot was mutually exclusive with linestyle, so I am using asaplotbase.hist instead.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 23.3 KB
Line 
1"""
2ASAP plotting class based on matplotlib.
3"""
4
5import sys
6from re import match
7
8import matplotlib
9
10from matplotlib.figure import Figure, Text
11from matplotlib.font_manager import FontProperties
12from matplotlib.numerix import sqrt
13from matplotlib import rc, rcParams
14from asap import rcParams as asaprcParams
15
16from matplotlib.ticker import ScalarFormatter
17from matplotlib.ticker import NullLocator
18
class MyFormatter(ScalarFormatter):
    """
    ScalarFormatter that blanks a single tick label near the upper end of
    the axis.

    asaplotbase.set_panels() installs this on the x-axis of ganged panels
    that are not in the rightmost column, presumably so that the blanked
    label does not collide with the first label of the neighbouring panel.
    """
    def __call__(self, x, pos=None):
        """
        Return the label for tick value x at tick index pos, or '' for the
        tick at index len(self.locs) - 2.  All other ticks are formatted by
        ScalarFormatter.
        """
        # NOTE(review): this blanks the second-to-last entry of self.locs,
        # not the last one; presumably the last loc falls outside the visible
        # axis range -- confirm against the matplotlib version in use.
        if pos == len(self.locs) - 2:
            return ''
        return ScalarFormatter.__call__(self, x, pos)
26
class asaplotbase:
    """
    ASAP plotting base class based on matplotlib.

    Holds a matplotlib Figure divided into one or more panels (subplots),
    keeps track of the lines plotted in each panel, and cycles through a
    colour map (and, for a single-colour map, through dash patterns) for
    successive lines.

    NOTE(review): this class reads self.events and self.figmgr (register())
    and uses self.canvas (save(), region(), register()) but never sets the
    first two and initialises canvas to None; presumably a GUI subclass
    provides them.
    """

    def __init__(self, rows=1, cols=0, title='', size=(8,6), buffering=False):
        """
        Create a new instance of the ASAPlot plotting class.

        If rows < 1 then a separate call to set_panels() is required to define
        the panel layout; refer to the doctext for set_panels().

        rows, cols  panel layout passed to set_panels() when rows > 0.
        title       text drawn at the top centre of the figure.
        size        figure size passed to matplotlib's Figure(figsize=...).
        buffering   initial buffering state; while True, show() does nothing.
        """
        self.is_dead = False
        self.figure = Figure(figsize=size, facecolor='#ddddee')
        self.canvas = None

        # Current-panel state.  subplot() rebinds these once panels exist;
        # initialising them here keeps clear()/subplot(inc=...) usable when
        # rows < 1 and set_panels() has not been called yet.
        self.i = 0
        self.lines = []

        self.set_title(title)
        self.subplots = []
        if rows > 0:
            self.set_panels(rows, cols)

        # Default colour sequence, overridable with the whitespace-separated
        # 'plotter.colours' ASAP rc parameter.
        self.colormap = "green red black cyan magenta orange blue purple yellow pink".split()
        c = asaprcParams['plotter.colours']
        if isinstance(c, str) and len(c) > 0:
            self.colormap = c.split()

        # Named dash sequences, applied with Line2D.set_dashes() in plot().
        self.lsalias = {"line":  [1,0],
                        "dashdot": [4,2,1,2],
                        "dashed" : [4,2,4,2],
                        "dotted" : [1,2],
                        "dashdotdot": [4,2,1,2,1,2],
                        "dashdashdot": [4,2,4,2,1,2]
                        }

        styles = "line dashed dotted dashdot".split()
        c = asaprcParams['plotter.linestyles']
        if isinstance(c, str) and len(c) > 0:
            styles = c.split()
        self.linestyles = self._dash_patterns(styles)

        self.color = 0
        self.linestyle = 0
        self.attributes = {}
        self.loc = 0

        self.buffering = buffering

    def _dash_patterns(self, styles):
        """
        Translate style names (keys of self.lsalias) into dash sequences.
        Unknown names fall back to the solid 'line' pattern, because the
        result is passed to Line2D.set_dashes() and must be a sequence.
        """
        solid = self.lsalias['line']
        return [self.lsalias.get(ls, solid) for ls in styles]

    def clear(self):
        """
        Delete all lines from the current panel and clear its axes.  Line
        numbering will restart from 1 and the colour cycle restarts at the
        first colour.
        """
        self.delete(range(len(self.lines)))
        self.axes.clear()
        self.color = 0
        # Empty the list in place: self.lines is the same list object as the
        # current panel's subplots[...]['lines'], which show() reads to build
        # legends.  Rebinding it would detach later plot() calls from it.
        del self.lines[:]

    def palette(self, color, colormap=None, linestyle=0, linestyles=None):
        """
        Select the colour and dash pattern for the next line, optionally
        replacing the colour map and/or the list of line styles.

        color       0-relative index into the colour map; ignored if out of
                    range.
        colormap    new colour sequence: a list/tuple of colour names or a
                    whitespace-separated string.
        linestyle   0-relative index into the line styles; ignored if out of
                    range.
        linestyles  new style sequence: a list/tuple or whitespace-separated
                    string of names from self.lsalias ('line', 'dashed',
                    'dotted', 'dashdot', 'dashdotdot', 'dashdashdot');
                    unknown names map to 'line'.  Other types are ignored.
        """
        if colormap:
            if isinstance(colormap, str):
                self.colormap = colormap.split()
            elif isinstance(colormap, (list, tuple)):
                self.colormap = list(colormap)
        if 0 <= color < len(self.colormap):
            self.color = color
        elif self.color >= len(self.colormap):
            # A shorter colour map would otherwise leave plot() indexing
            # past its end.
            self.color = 0

        if linestyles:
            styles = None
            if isinstance(linestyles, str):
                styles = linestyles.split()
            elif isinstance(linestyles, (list, tuple)):
                styles = list(linestyles)
            if styles:
                self.linestyles = self._dash_patterns(styles)
        if 0 <= linestyle < len(self.linestyles):
            self.linestyle = linestyle
        elif self.linestyle >= len(self.linestyles):
            self.linestyle = 0

    def delete(self, numbers=None):
        """
        Delete the 0-relative line number(s), default is to delete the last.
        The remaining lines are NOT renumbered.

        numbers  a single line number or an iterable of them; out-of-range
                 and already-deleted numbers are ignored.

        A deleted line is hidden by setting its linestyle to 'None' (it is
        not removed from the axes) and its slot in self.lines becomes None.
        """
        if numbers is None:
            numbers = [len(self.lines) - 1]

        if not hasattr(numbers, '__iter__'):
            numbers = [numbers]

        for number in numbers:
            if 0 <= number < len(self.lines) and self.lines[number] is not None:
                for line in self.lines[number]:
                    line.set_linestyle('None')
                self.lines[number] = None
        self.show()

    def get_line(self):
        """
        Get the current default line attributes (the dict that set_line()
        fills when no existing line is addressed).
        """
        return self.attributes

    def hist(self, x=None, y=None, msk=None, fmt=None, add=None):
        """
        Plot a histogram.  N.B. the x values refer to the start of the
        histogram bin.

        The outline is drawn with plot() as a step line through
        (x0, 0), (x0, y0), (x1, y0), (x1, y1), ..., (x[n-1], y[n-1]),
        so the last bin has no right-hand edge.

        msk  optional per-point mask as in plot(); None plots every point.
        fmt  the line style as in plot().
        add  passed through to plot().

        As with plot(), if only x is given it is used as y.  Nothing is
        plotted if x and y have different lengths or are empty.
        """
        if x is None:
            if y is None: return
            x = range(len(y))
        elif y is None:
            y = x
            x = range(len(y))

        if len(x) != len(y) or len(x) == 0:
            return

        n2 = 2 * len(x)
        # Every input point contributes two vertices: x (and the mask) are
        # duplicated, y is lagged by one vertex to form the steps.
        x2 = [x[k // 2] for k in range(n2)]
        y2 = [0.0] + [y[(k - 1) // 2] for k in range(1, n2)]
        if msk is None:
            m2 = None
        else:
            m2 = [msk[k // 2] for k in range(n2)]

        self.plot(x2, y2, m2, fmt, add)

    def hold(self, hold=True):
        """
        Buffer graphics until subsequently released.
        """
        self.buffering = hold

    def legend(self, loc=None):
        """
        Add a legend to the plot.

        loc is an integer legend position code:
             0: passed to matplotlib unchanged (its automatic placement)
             1: upper right
             2: upper left
             3: lower left
             4: lower right
             5: right
             6: center left
             7: center right
             8: lower center
             9: upper center
            10: center
        Any other integer sets the location to None, so show() no longer
        draws legends.  A non-integer loc (including the default None)
        keeps the current setting and just calls show().
        """
        if isinstance(loc, int):
            if 0 <= loc <= 10:
                self.loc = loc
            else:
                self.loc = None
        self.show()

    def _masked_segments(self, x, y, mask, fmt=None):
        """
        Split x and y into the runs where mask is true and return them as a
        flat Axes.plot() argument list [x0, y0, (fmt,) x1, y1, (fmt,) ...].
        """
        args = []
        start = None
        # A trailing False closes a run that extends to the end of the mask.
        for k, flag in enumerate(list(mask) + [False]):
            if flag and start is None:
                start = k
            elif not flag and start is not None:
                args.append(x[start:k])
                args.append(y[start:k])
                if fmt is not None:
                    args.append(fmt)
                start = None
        return args

    def plot(self, x=None, y=None, mask=None, fmt=None, add=None):
        """
        Plot the next line in the current frame using the current line
        attributes, then call show().

        The argument list works a bit like the matlab plot() function: with
        only x given, it is used as y and plotted against its indices.

        mask  optional sequence parallel to x/y, treated as booleans; only
              the runs of true entries are drawn, each run as a separate
              segment of the same logical line.
        fmt   optional matplotlib format string (e.g. 'r-'), also applied
              to every masked segment.
        add   1-relative number of an existing line to extend instead of
              starting a new one; 0 means the most recent line.  None, an
              out-of-range number, or a deleted line starts a new line.

        Attributes from set_line() are applied to the line; if they do not
        include a colour, the next colour of the colour map is used (and,
        for a single-colour map, the next dash pattern).
        """
        if x is None:
            if y is None: return
            x = range(len(y))

        elif y is None:
            y = x
            x = range(len(y))

        if mask is None:
            if fmt is None:
                line = self.axes.plot(x, y)
            else:
                line = self.axes.plot(x, y, fmt)
        else:
            line = self.axes.plot(*self._masked_segments(x, y, mask, fmt))

        # Add to an existing line?
        if add == 0:
            add = len(self.lines)
        if add is not None and 1 <= add <= len(self.lines) \
                and self.lines[add - 1] is not None:
            i = add - 1
            self.lines[i].extend(line)
        else:
            self.lines.append(line)
            i = len(self.lines) - 1

        # Set/reset attributes for the line (all of its segments).
        segments = self.lines[i]
        gotcolour = 'color' in self.attributes
        for k, v in self.attributes.items():
            for segment in segments:
                getattr(segment, "set_%s" % k)(v)

        if not gotcolour and len(self.colormap):
            # With a single colour, successive lines are told apart by dash
            # pattern instead.
            cycle_styles = len(self.colormap) == 1 and len(self.linestyles) > 0
            for segment in segments:
                segment.set_color(self.colormap[self.color])
                if cycle_styles:
                    segment.set_dashes(self.linestyles[self.linestyle])
            self.color += 1
            if self.color >= len(self.colormap):
                self.color = 0

            if cycle_styles:
                self.linestyle += 1
            if self.linestyle >= len(self.linestyles):
                self.linestyle = 0

        self.show()

    def position(self):
        """
        Use the mouse to get a position from a graph.

        Registers a one-shot 'button_press' handler that prints the world
        coordinates of the next click and then deregisters itself.
        """

        def position_disable(event):
            self.register('button_press', None)
            print('%.4f, %.4f' % (event.xdata, event.ydata))

        print('Press any mouse button...')
        self.register('button_press', position_disable)

    def region(self):
        """
        Use the mouse to get a rectangular region from a plot.

        A press starts the region, dragging draws a rubber-band rectangle on
        the Tk canvas (self.canvas._tkcanvas), and the release prints the
        corners as "(x0, y0)  (x1, y1)" in world coordinates.

        Currently returns [0.0, 0.0, 0.0, 0.0] immediately; the selected
        region is only printed, not returned.
        """

        def region_start(event):
            height = self.canvas.figure.bbox.height()
            self.rect = {'fig': None, 'height': height,
                         'x': event.x, 'y': height - event.y,
                         'world': [event.xdata, event.ydata,
                                   event.xdata, event.ydata]}
            self.register('button_press', None)
            self.register('motion_notify', region_draw)
            self.register('button_release', region_disable)

        def region_draw(event):
            # Tk's y axis points down, matplotlib's points up.
            self.canvas._tkcanvas.delete(self.rect['fig'])
            self.rect['fig'] = self.canvas._tkcanvas.create_rectangle(
                                self.rect['x'], self.rect['y'],
                                event.x, self.rect['height'] - event.y)

        def region_disable(event):
            self.register('motion_notify', None)
            self.register('button_release', None)

            self.canvas._tkcanvas.delete(self.rect['fig'])

            self.rect['world'][2:4] = [event.xdata, event.ydata]
            print('(%.2f, %.2f)  (%.2f, %.2f)' % (self.rect['world'][0],
                self.rect['world'][1], self.rect['world'][2],
                self.rect['world'][3]))

        self.register('button_press', region_start)

        # This has to be modified to block and return the result (currently
        # printed by region_disable) when that becomes possible in matplotlib.

        return [0.0, 0.0, 0.0, 0.0]

    def register(self, type=None, func=None):
        """
        Register, reregister, or deregister events of type 'button_press',
        'button_release', or 'motion_notify'.

        The specified callback function should have the following signature:

            def func(event)

        where event is an MplEvent instance containing the following data:

            name                # Event name.
            canvas              # FigureCanvas instance generating the event.
            x      = None       # x position - pixels from left of canvas.
            y      = None       # y position - pixels from bottom of canvas.
            button = None       # Button pressed: None, 1, 2, 3.
            key    = None       # Key pressed: None, chr(range(255)), shift,
                                  win, or control
            inaxes = None       # Axes instance if cursor within axes.
            xdata  = None       # x world coordinate.
            ydata  = None       # y world coordinate.

        For example:

            def mouse_move(event):
                print event.xdata, event.ydata

            a = asaplot()
            a.register('motion_notify', mouse_move)

        If func is None, the event is deregistered and the matching toolbar
        handler is reconnected.  Types that are not keys of self.events are
        ignored.

        Note that in TkAgg keyboard button presses don't generate an event.
        """

        if type not in self.events: return

        if func is None:
            if self.events[type] is not None:
                # It's not clear that this does anything.
                self.canvas.mpl_disconnect(self.events[type])
                self.events[type] = None

                # It seems to be necessary to return events to the toolbar.
                handler = {'motion_notify': 'mouse_move',
                           'button_press': 'press',
                           'button_release': 'release'}.get(type)
                if handler is not None:
                    self.canvas.mpl_connect(type + '_event',
                        getattr(self.figmgr.toolbar, handler))

        else:
            self.events[type] = self.canvas.mpl_connect(type + '_event', func)

    def release(self):
        """
        Release buffered graphics.
        """
        self.buffering = False
        self.show()

    def save(self, fname=None, orientation=None, dpi=None):
        """
        Save the plot to a file.

        fname is the name of the output file.  The image format is determined
        from the file suffix; 'png', 'ps', and 'eps' are recognized.  If no
        file name is specified 'asapyyyymmdd_hhmmss.png' is created in the
        current directory.  Environment variables in fname are expanded.

        orientation  PostScript only: 'landscape' or 'portrait'; by default
                     landscape when the figure is wider than it is tall.
        dpi          resolution for non-PostScript output (default 150).

        An IOError while writing is reported on stdout, not raised.
        """
        if fname is None:
            from datetime import datetime
            dstr = datetime.now().strftime('%Y%m%d_%H%M%S')
            fname = 'asap' + dstr + '.png'

        from os.path import expandvars
        fname = expandvars(fname)

        suffix = fname[-3:].lower()
        if suffix not in ('png', '.ps', 'eps'):
            print("Invalid image type. Valid types are:")
            print("'ps', 'eps', 'png'")
            return

        try:
            if suffix == '.ps':
                from matplotlib import __version__ as mv
                w = self.figure.figwidth.get()
                h = self.figure.figheight.get()

                if orientation is None:
                    # auto oriented
                    if w > h:
                        orientation = 'landscape'
                    else:
                        orientation = 'portrait'
                # Hack to circumvent a PostScript bug in early versions of
                # matplotlib (< 0.86): scale the figure to fit an
                # 8.25 x 11.25 inch page.  Compare (major, minor) so that
                # e.g. 1.0 is not mistaken for an early version.
                version = match(r'(\d+)\.(\d+)', mv)
                if version is not None and \
                   (int(version.group(1)), int(version.group(2))) < (0, 86):
                    a4w = 8.25
                    a4h = 11.25
                    if orientation == 'landscape':
                        ds = min(a4h/w, a4w/h)
                    else:
                        ds = min(a4w/w, a4h/h)
                    self.figure.set_figsize_inches((ds * w, ds * h))
                self.canvas.print_figure(fname, orientation=orientation)
            else:
                if dpi is None:
                    dpi = 150
                self.canvas.print_figure(fname, dpi=dpi)
            print('Written file %s' % (fname))
        except IOError:
            # Fetched via sys.exc_info() so the handler is valid syntax on
            # both old and new Python versions.
            err = sys.exc_info()[1]
            print('Failed to save %s: Error msg was\n\n%s' % (fname, err))

    def set_axes(self, what=None, *args, **kwargs):
        """
        Set attributes for the current axes by calling the relevant
        Axes.set_*() method, e.g. set_axes('xlabel', 'Channel').  A trailing
        'colour' in what and a 'colour' keyword are translated to 'color';
        keyword names are lower-cased.

        Afterwards the title font (and, for more than one column/row, the
        x/y axis label fonts) is shrunk according to the panel layout.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v

        getattr(self.axes, "set_%s" % what)(*args, **newargs)

        # NOTE(review): these adjustments run on every call, whatever `what`
        # is, and each starts from the current size, so repeated calls keep
        # shrinking the title and axis labels -- confirm this is intended.
        s = self.axes.title.get_size()
        tsize = s - (self.cols + self.rows)//2 - 1
        self.axes.title.set_size(tsize)
        if self.cols > 1:
            xfsize = self.axes.xaxis.label.get_size() - (self.cols + 1)//2
            self.axes.xaxis.label.set_size(xfsize)
        if self.rows > 1:
            yfsize = self.axes.yaxis.label.get_size() - (self.rows + 1)//2
            self.axes.yaxis.label.set_size(yfsize)

        self.show()

    def set_figure(self, what=None, *args, **kwargs):
        """
        Set attributes for the figure by calling the relevant Figure.set_*()
        method.  A trailing 'colour' in what and a 'colour' keyword are
        translated to 'color'; keyword names are lower-cased.
        """

        if what is None: return
        if what[-6:] == 'colour': what = what[:-6] + 'color'

        newargs = {}
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'
            newargs[k] = v

        getattr(self.figure, "set_%s" % what)(*args, **newargs)
        self.show()

    def set_limits(self, xlim=None, ylim=None):
        """
        Set x-, and y-limits for each subplot.

        xlim = [xmin, xmax] as in axes.set_xlim().
        ylim = [ymin, ymax] as in axes.set_ylim().

        A None limit, or a None entry within one, keeps the panel's existing
        value.  The current panel (self.axes / self.lines) is not changed.
        """
        for s in self.subplots:
            axes = s['axes']
            axes.set_xlim(self._merge_limits(axes.get_xlim(), xlim))
            axes.set_ylim(self._merge_limits(axes.get_ylim(), ylim))

    def _merge_limits(self, old, new):
        """Return list(old) with each non-None entry of new substituted."""
        merged = list(old)
        if new is not None:
            for k, v in enumerate(new):
                if v is not None:
                    merged[k] = v
        return merged

    def set_line(self, number=None, **kwargs):
        """
        Set attributes for the specified line, or else the next line(s)
        to be plotted.

        number is the 0-relative number of a line that has already been
        plotted.  If no such line exists, attributes are recorded and used
        for the next line(s) to be plotted.

        Keyword arguments specify Line2D attributes, e.g. color='r'.  Do

            import matplotlib
            help(matplotlib.lines)

        The set_* methods of class Line2D define the attribute names and
        values.  For non-US usage, "colour" is recognized as synonymous with
        "color".  Attribute names are lower-cased.

        Set the value to None to delete a recorded attribute (deleting one
        that was never set is a no-op).
        """

        redraw = False
        for k, v in kwargs.items():
            k = k.lower()
            if k == 'colour': k = 'color'

            if number is not None and 0 <= number < len(self.lines):
                if self.lines[number] is not None:
                    for line in self.lines[number]:
                        getattr(line, "set_%s" % k)(v)
                    redraw = True
            elif v is None:
                self.attributes.pop(k, None)
            else:
                self.attributes[k] = v

        if redraw: self.show()

    def set_panels(self, rows=1, cols=0, n=-1, nplots=-1, ganged=True):
        """
        Set the panel layout.

        rows and cols, if cols != 0, specify the number of rows and columns in
        a regular layout.   (Indexing of these panels in matplotlib is row-
        major, i.e. column varies fastest.)

        cols == 0 is interpreted as a rectangular layout that accommodates
        'rows' panels, e.g. rows == 6, cols == 0 is equivalent to
        rows == 2, cols == 3.

        0 <= n < rows*cols is interpreted as the 0-relative panel number in
        the configuration specified by rows and cols to be added to the
        current figure as its next 0-relative panel number (i).  This allows
        non-regular panel layouts to be constructed via multiple calls.  Any
        other value of n clears the plot and produces a rectangular array of
        empty panels.  The number of these may be limited by nplots.

        With ganged=True (regular layouts only) the panels are enlarged to
        squeeze them together, and tick labelling is suppressed on interior
        panel edges.
        """
        if n < 0 and len(self.subplots):
            self.figure.clear()
            self.set_title()

        if rows < 1: rows = 1

        if cols <= 0:
            # Smallest near-square grid holding `rows` panels.
            i = int(sqrt(rows))
            if i*i < rows: i += 1
            cols = i

            if i*(i-1) >= rows: i -= 1
            rows = i

        if 0 <= n < rows*cols:
            i = len(self.subplots)
            self.subplots.append({
                'axes': self.figure.add_subplot(rows, cols, n+1),
                'lines': []})

            if i == 0: self.subplot(0)

            self.rows = 0
            self.cols = 0

        else:
            self.subplots = []

            if nplots < 1 or rows*cols < nplots:
                nplots = rows*cols

            for i in range(nplots):
                axes = self.figure.add_subplot(rows, cols, i+1)
                self.subplots.append({'axes': axes, 'lines': []})

                if ganged:
                    if rows > 1 or cols > 1:
                        # Squeeze the plots together.
                        pos = axes.get_position()
                        if cols > 1: pos[2] *= 1.2
                        if rows > 1: pos[3] *= 1.2
                        axes.set_position(pos)

                    # Suppress tick labelling for interior subplots.
                    if i + cols < nplots:
                        # Another panel lies directly below this one: remove
                        # its major x ticks and hide its x-axis label.
                        axes.xaxis.set_major_locator(NullLocator())
                        axes.xaxis.label.set_visible(False)
                    if i % cols:
                        # Suppress y-labels for frames not in the left column.
                        for tick in axes.yaxis.majorTicks:
                            tick.label1On = False
                        axes.yaxis.label.set_visible(False)
                    if (i+1) % cols:
                        # Not in the rightmost column: MyFormatter blanks one
                        # tick label near the upper end of the x-axis.
                        axes.xaxis.set_major_formatter(MyFormatter())

            self.rows = rows
            self.cols = cols

            self.subplot(0)

    def set_title(self, title=None):
        """
        Set the title of the plot window.  Use the previous title if title is
        omitted.

        The title is drawn as figure text at the top centre.  The text added
        by the previous call is removed first (if it is still part of the
        figure) so repeated calls do not draw overlapping titles.
        """
        if title is not None:
            self.title = title

        previous = getattr(self, '_titletext', None)
        if previous is not None and previous in self.figure.texts:
            self.figure.texts.remove(previous)
        self._titletext = self.figure.text(0.5, 0.95, self.title,
                                           horizontalalignment='center')

    def show(self):
        """
        Show graphics dependent on the current buffering state.

        Unless buffering (or self.loc is None), rebuilds the legend of every
        panel at location self.loc.  Lines whose label is '' are listed under
        their 1-relative line number; deleted or empty lines are skipped.
        This method does not itself redraw the canvas.
        """
        if self.buffering or self.loc is None:
            return

        for subplot in self.subplots:
            handles = []
            labels = []
            for number, line in enumerate(subplot['lines']):
                if not line:
                    continue
                handles.append(line[0])
                label = line[0].get_label()
                if label == '':
                    label = str(number + 1)
                labels.append(label)

            if handles:
                subplot['axes'].legend(tuple(handles), tuple(labels),
                                       self.loc)
            else:
                # NOTE(review): (' ') is the plain string ' ', not a tuple;
                # presumably meant to replace any old legend with an empty
                # one -- confirm against the matplotlib version in use.
                subplot['axes'].legend((' '))

    def subplot(self, i=None, inc=None):
        """
        Set the subplot to the 0-relative panel number as defined by one or
        more invocations of set_panels().

        i    panel number to select.
        inc  amount to add to the panel number (after applying i).
        The result wraps modulo the number of panels; nothing happens when
        there are no panels.
        """
        l = len(self.subplots)
        if l:
            if i is not None:
                self.i = i

            if inc is not None:
                self.i += inc

            self.i %= l
            self.axes  = self.subplots[self.i]['axes']
            self.lines = self.subplots[self.i]['lines']

    def text(self, *args, **kwargs):
        """
        Add text to the figure; arguments are passed to Figure.text().
        """
        self.figure.text(*args, **kwargs)
        self.show()
Note: See TracBrowser for help on using the repository browser.