py4u guide

How to Create Data Visualizations with Python’s Matplotlib

Data visualization is a cornerstone of data analysis, enabling us to transform raw numbers into intuitive insights. Whether you’re exploring trends, comparing datasets, or communicating findings, a well-crafted visualization can make complex information accessible at a glance. In the Python ecosystem, **Matplotlib** stands as the most widely used library for creating static, animated, and interactive visualizations. Developed by John D. Hunter in 2003, Matplotlib is open-source, highly customizable, and integrates seamlessly with other Python libraries like NumPy (for numerical computing) and Pandas (for data manipulation). Its flexibility allows you to generate everything from simple line charts to complex 3D plots, making it a go-to tool for beginners and experts alike. This blog will guide you through the essentials of Matplotlib, from installation to advanced customization, with hands-on examples to help you create publication-ready visualizations.

Table of Contents

  1. Installation
  2. Basic Concepts: Figure and Axes
  3. Common Plot Types
  4. Creating Subplots
  5. Customizing Plots
  6. Saving Plots
  7. Advanced Tips and Best Practices
  8. Conclusion
  9. References

Installation

Before diving in, ensure Matplotlib is installed. Use pip (Python’s package installer) or conda (if using Anaconda):

Using pip:

pip install matplotlib

Using conda:

conda install matplotlib

To verify installation, run this in a Python shell:

import matplotlib
print(matplotlib.__version__)  # Should output the installed version (e.g., 3.8.0)

We’ll also use NumPy for generating sample data, so install it if you haven’t:

pip install numpy  # or conda install numpy

Basic Concepts: Figure and Axes

Matplotlib’s core architecture revolves around two key objects: Figure and Axes. Understanding these is critical to mastering the library.

  • Figure: The outermost container for all plot elements. Think of it as a canvas that can hold one or more plots (called Axes).
  • Axes: The actual plot area where data is rendered. An Axes has an x-axis, y-axis, labels, title, and other elements. A Figure can contain multiple Axes (e.g., subplots).

Creating a Simple Figure and Axes

Use plt.subplots() to create a Figure and an Axes object in one call. Here’s a minimal example:

import matplotlib.pyplot as plt
import numpy as np

# Create a figure and a single axes (plot area)
fig, ax = plt.subplots()

# Generate sample data
x = np.linspace(0, 10, 100)  # 100 evenly spaced points between 0 and 10
y = np.sin(x)

# Plot data on the axes
ax.plot(x, y)

# Display the plot
plt.show()

  • fig: The Figure object (canvas).
  • ax: The Axes object (plot area).
  • plt.show(): Displays the figure (required in scripts; Jupyter notebooks usually render figures automatically at the end of a cell).
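
To make the Figure/Axes relationship concrete, here is a small sketch (an addition to this walkthrough, not needed by later examples) that inspects the objects plt.subplots() returns. It relies only on the standard fig.axes, ax.figure, and fig.add_axes() APIs; the inset rectangle [0.6, 0.6, 0.25, 0.25] is an arbitrary choice for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)

# ax.plot returns a list of Line2D artists; unpack the single line
(line,) = ax.plot(x, np.sin(x))

# The Figure tracks its Axes, and every artist knows its parents
print(len(fig.axes))     # 1
print(ax.figure is fig)  # True
print(line.axes is ax)   # True

# A Figure can hold more than one Axes: add a small inset
# rect = [left, bottom, width, height] as fractions of the figure size
inset = fig.add_axes([0.6, 0.6, 0.25, 0.25])
inset.plot(x, np.cos(x))
print(len(fig.axes))     # 2
```

The Figure starts with one Axes and holds two after the inset is added, which is exactly the container relationship described above.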

Common Plot Types

Matplotlib supports dozens of plot types. Below are the most widely used, with examples.

Line Plots

Line plots are ideal for showing trends over time (e.g., stock prices, temperature changes). Use ax.plot(x, y).

Example: Plotting a sine wave with customization

fig, ax = plt.subplots(figsize=(10, 4))  # figsize=(width, height) in inches

x = np.linspace(0, 2*np.pi, 100)
y_sin = np.sin(x)
y_cos = np.cos(x)

# Plot sine and cosine curves
ax.plot(x, y_sin, label='sin(x)', color='blue', linestyle='-', linewidth=2)
ax.plot(x, y_cos, label='cos(x)', color='red', linestyle='--', linewidth=2)

# Add labels and title
ax.set_xlabel('X-axis (radians)', fontsize=12)
ax.set_ylabel('Y-axis', fontsize=12)
ax.set_title('Sine and Cosine Waves', fontsize=14, pad=20)

# Add legend and grid
ax.legend(loc='upper right')  # loc specifies legend position
ax.grid(True, linestyle=':', alpha=0.7)  # alpha controls transparency

plt.show()
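
As a shortcut, ax.plot() also accepts a format string that bundles color, marker, and line style, such as 'r--' for a red dashed line. This sketch draws a curve both ways and checks that the two lines end up with the same color and style; same_color from matplotlib.colors is used only for that comparison, and the 0.1 vertical shift just keeps the two lines from overlapping.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import same_color

x = np.linspace(0, 2*np.pi, 100)
fig, ax = plt.subplots(figsize=(10, 4))

# Keyword form, as in the example above
(long_form,) = ax.plot(x, np.cos(x), color='red', linestyle='--')

# Format-string shorthand: 'r' = red, '--' = dashed ('o' would add circle markers)
(short_form,) = ax.plot(x, np.cos(x) + 0.1, 'r--')

print(same_color(long_form.get_color(), short_form.get_color()))  # True
print(long_form.get_linestyle() == short_form.get_linestyle())    # True
```

The keyword form is more readable in longer scripts; the format string is handy for quick exploratory plots.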

Scatter Plots

Scatter plots visualize relationships between two numerical variables. Use ax.scatter(x, y).

Example: Correlation between two variables

fig, ax = plt.subplots(figsize=(8, 6))

# Generate random data with a strong positive correlation (y depends linearly on x)
np.random.seed(42)  # For reproducibility
x = np.random.randn(100)  # 100 samples from normal distribution
y = 2*x + np.random.randn(100)*0.5  # y = 2x + noise

# Scatter plot with custom markers and size
ax.scatter(x, y, 
           color='green', 
           marker='o',  # 'o'=circle, 's'=square, '^'=triangle
           s=50,  # marker size
           alpha=0.6,  # transparency to avoid overcrowding
           edgecolor='black')  # border around markers

ax.set_xlabel('X Variable', fontsize=12)
ax.set_ylabel('Y Variable', fontsize=12)
ax.set_title('Scatter Plot: Y vs X', fontsize=14)
ax.grid(True)

plt.show()
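
ax.scatter() can also encode a third variable as color through its c and cmap parameters. The sketch below reuses the same synthetic data and, purely as an illustrative choice, colors each point by its noise term y - 2x; np.corrcoef then quantifies how strongly x and y are related.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.randn(100)
y = 2*x + np.random.randn(100)*0.5

noise = y - 2*x  # the random component added to each point

fig, ax = plt.subplots(figsize=(8, 6))
# c= supplies one value per point; cmap maps those values to colors
points = ax.scatter(x, y, c=noise, cmap='viridis', s=50, edgecolor='black')
fig.colorbar(points, ax=ax, label='Noise (y - 2x)')  # adds a color scale as a second Axes
ax.set_xlabel('X Variable')
ax.set_ylabel('Y Variable')

# Pearson correlation coefficient between x and y
r = np.corrcoef(x, y)[0, 1]
print(f"Pearson r = {r:.2f}")
```

With a slope of 2 and noise of standard deviation 0.5, the correlation comes out close to 1, confirming the relationship is strong rather than weak.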

Bar Plots

Bar plots compare categorical data (e.g., sales by region, exam scores by class). Use ax.bar(categories, values) for vertical bars or ax.barh() for horizontal bars.

Example: Vertical and horizontal bar plots

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  # 1 row, 2 columns of axes

# Sample data
categories = ['A', 'B', 'C', 'D', 'E']
values = [25, 40, 30, 35, 20]

# Vertical bar plot
ax1.bar(categories, values, color='skyblue', edgecolor='black')
ax1.set_title('Vertical Bar Plot')
ax1.set_xlabel('Categories')
ax1.set_ylabel('Values')

# Horizontal bar plot
ax2.barh(categories, values, color='salmon', edgecolor='black')
ax2.set_title('Horizontal Bar Plot')
ax2.set_xlabel('Values')
ax2.set_ylabel('Categories')

plt.tight_layout()  # Adjust spacing so titles and axis labels don't overlap
plt.show()
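
To compare two series per category, a common pattern is to place bars at integer x positions and shift each series by half a bar width. The sketch below assumes a hypothetical second series (values_b) alongside the values from the example above, and uses ax.bar_label(), which requires Matplotlib 3.4 or newer.

```python
import matplotlib.pyplot as plt
import numpy as np

categories = ['A', 'B', 'C', 'D', 'E']
values_a = [25, 40, 30, 35, 20]   # values from the example above
values_b = [22, 35, 33, 30, 26]   # hypothetical second series

x = np.arange(len(categories))    # one slot per category: 0, 1, ..., 4
width = 0.4                       # each bar takes 40% of a slot

fig, ax = plt.subplots(figsize=(8, 5))
bars_a = ax.bar(x - width/2, values_a, width, label='Series 1')
bars_b = ax.bar(x + width/2, values_b, width, label='Series 2')

ax.set_xticks(x)
ax.set_xticklabels(categories)    # category names centered under each group
ax.bar_label(bars_a)              # print each bar's value above it
ax.bar_label(bars_b)
ax.set_ylabel('Values')
ax.legend()
fig.tight_layout()
```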

Histograms

Histograms show the distribution of a single numerical variable by grouping data into bins. Use ax.hist(data, bins).

Example: Distribution of heights

fig, ax = plt.subplots(figsize=(8, 5))

# Generate data: heights (mean=170cm, std=10cm)
np.random.seed(42)
heights = np.random.normal(loc=170, scale=10, size=1000)  # Normal distribution

# Plot histogram
n, bins, patches = ax.hist(
    heights, 
    bins=20,  # Number of bins
    color='purple', 
    edgecolor='white', 
    alpha=0.7
)

ax.set_title('Distribution of Heights (cm)', fontsize=14)
ax.set_xlabel('Height (cm)')
ax.set_ylabel('Frequency')
ax.grid(axis='y', linestyle='--', alpha=0.7)  # Grid only on y-axis

plt.show()
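
ax.hist() returns the bin counts n, the bin edges, and the drawn patches, which is why the example above unpacks three values. In this sketch, density=True rescales the bars so their total area is 1, which allows a probability density to be overlaid on the same scale; the curve is the normal distribution the sample was drawn from (mean 170, standard deviation 10), computed directly with NumPy.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
heights = np.random.normal(loc=170, scale=10, size=1000)

fig, ax = plt.subplots(figsize=(8, 5))
# density=True rescales bar heights so the total bar area equals 1
n, bins, patches = ax.hist(heights, bins=20, density=True,
                           color='purple', edgecolor='white', alpha=0.7)

print(len(n), len(bins))  # 20 bins have 21 edges

# Total area = sum of (bar height * bin width)
area = np.sum(n * np.diff(bins))

# Overlay the normal probability density the data were drawn from
mu, sigma = 170, 10
xs = np.linspace(bins[0], bins[-1], 200)
pdf = np.exp(-0.5 * ((xs - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
ax.plot(xs, pdf, color='black', linewidth=2, label='Normal(170, 10) PDF')
ax.set_xlabel('Height (cm)')
ax.set_ylabel('Density')
ax.legend()
```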

Pie Charts

Pie charts display proportions of a whole (e.g., market share). Use ax.pie(sizes, labels).

Note: Avoid pie charts for more than 5-6 categories—they become hard to read. Use bar plots instead for clarity.

Example: Market share of smartphone brands

fig, ax = plt.subplots(figsize=(6, 6))  # Square figure for circular pie

sizes = [35, 25, 20, 15, 5]  # Proportions (sum to 100%)
labels = ['Apple', 'Samsung', 'Huawei', 'Xiaomi', 'Others']
colors = ['#ff9999', '#66b3ff', '#99ff99', '#ffcc99', '#c2c2f0']

# Explode the largest slice (Apple) for emphasis
explode = (0.1, 0, 0, 0, 0)  # Offset the 1st slice by 10% of the radius

ax.pie(
    sizes, 
    explode=explode, 
    labels=labels, 
    colors=colors, 
    autopct='%1.1f%%',  # Display percentages
    shadow=True,  # Add shadow
    startangle=90  # Rotate pie to start at 90 degrees
)

ax.set_title('Smartphone Market Share', fontsize=14)
ax.axis('equal')  # Ensures pie is drawn as a circle

plt.show()
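
Following the note above about pie charts, the same market-share numbers can be shown as a horizontal bar chart, where lengths are often easier to compare than angles. This sketch reuses the example's data; ax.bar_label() requires Matplotlib 3.4 or newer.

```python
import matplotlib.pyplot as plt

sizes = [35, 25, 20, 15, 5]
labels = ['Apple', 'Samsung', 'Huawei', 'Xiaomi', 'Others']

fig, ax = plt.subplots(figsize=(7, 4))
bars = ax.barh(labels, sizes, color='#66b3ff', edgecolor='black')
ax.invert_yaxis()                             # first (largest) entry at the top
annotations = ax.bar_label(bars, fmt='%d%%')  # label each bar with its percentage
ax.set_xlabel('Market share (%)')
ax.set_title('Smartphone Market Share')
```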

Creating Subplots

Subplots let you display multiple plots in a single figure. Use plt.subplots(nrows, ncols) to create a grid of Axes.

Example: 2x2 grid of plots

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 8))  # 2x2 grid

# Flatten axes for easy iteration (optional)
axes = axes.flatten()

# Plot 1: Line plot
x = np.linspace(0, 10, 100)
axes[0].plot(x, np.sin(x), color='blue')
axes[0].set_title('Sine Wave')

# Plot 2: Scatter plot
axes[1].scatter(np.random.randn(50), np.random.randn(50), color='red')
axes[1].set_title('Random Scatter')

# Plot 3: Bar plot
axes[2].bar(['A', 'B', 'C'], [10, 20, 15], color='green')
axes[2].set_title('Simple Bar Plot')

# Plot 4: Histogram
axes[3].hist(np.random.randn(1000), bins=30, color='purple')
axes[3].set_title('Normal Distribution')

plt.tight_layout(pad=2)  # Extra padding around and between subplots (as a fraction of the font size)
plt.show()
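
When subplots show related data, sharex=True (or sharey=True) keeps their axis limits in sync, and iterating with zip() over axes.flat avoids hard-coded indices. The four functions below are arbitrary examples chosen for this sketch.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
funcs = [np.sin, np.cos, np.tanh, np.sqrt]          # arbitrary example functions
titles = ['sin(x)', 'cos(x)', 'tanh(x)', 'sqrt(x)']

# sharex=True links the x-axis limits of all four subplots
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6), sharex=True)

# axes.flat walks the 2x2 array row by row
for ax, func, title in zip(axes.flat, funcs, titles):
    ax.plot(x, func(x))
    ax.set_title(title)

# Changing the limits on one shared axis updates all of them
axes[0, 0].set_xlim(0, 5)
fig.tight_layout()
```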

Customizing Plots

Matplotlib’s true power lies in customization. Below are key techniques to refine your plots.

Colors, Markers, and Line Styles

  • Colors: Use named colors ('red', 'blue'), hex codes ('#FF5733'), or RGB tuples ((0.2, 0.4, 0.6)).
  • Line styles: '-' (solid), '--' (dashed), ':' (dotted), '-.' (dash-dot).
  • Markers: 'o' (circle), 's' (square), '^' (triangle), '*' (star).

Example: Custom line plot

fig, ax = plt.subplots(figsize=(8, 4))

x = np.linspace(0, 5, 20)
y = x**2

ax.plot(
    x, y, 
    color='#2E8B57',  # Sea green (hex code)
    linestyle='-.', 
    linewidth=2,
    marker='^', 
    markersize=8, 
    markerfacecolor='yellow',  # Marker fill color
    markeredgecolor='black'    # Marker border color
)

ax.set_title('Customized Line Plot')
plt.show()
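
To check how Matplotlib interprets a color specification, matplotlib.colors.to_hex() and to_rgb() convert any accepted format to a canonical value. This sketch shows that the sea-green hex code used above matches the named CSS color 'seagreen'. 'C0' refers to the first color of the active property cycle, so its exact value depends on the current style.

```python
from matplotlib.colors import to_hex, to_rgb

# Named color, hex code, and RGB tuple all resolve to the same kind of value
print(to_hex('seagreen'))       # → #2e8b57 (named CSS color)
print(to_hex('#2E8B57'))        # → #2e8b57 (hex codes are case-insensitive)
print(to_hex((0.2, 0.4, 0.6)))  # → #336699 (RGB tuple with components in 0-1)

# 'C0', 'C1', ... index into the current color cycle
print(to_rgb('C0'))
```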

Labels, Titles, and Legends

  • Labels: Use ax.set_xlabel() and ax.set_ylabel() for axis labels.
  • Title: Use ax.set_title(), with pad to adjust spacing from the plot.
  • Legends: Use ax.legend() to explain multiple lines/bars. Pass label to ax.plot() for legend entries.

Example: Legend with custom positioning

fig, ax = plt.subplots()

ax.plot([1, 2, 3], [2, 4, 1], label='Line 1')
ax.plot([1, 2, 3], [5, 1, 3], label='Line 2')

ax.legend(loc='lower center', ncol=2, fontsize=10, title='Legend Title')
# loc: position; ncol: number of columns; title: legend title

plt.show()
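
When a legend would cover data, it can be placed outside the axes with bbox_to_anchor, which is given in axes coordinates: (1.02, 1) lies just right of the top-right corner, and loc='upper left' pins the legend's upper-left corner to that point. This sketch reuses the two lines above; the right=0.75 margin is an arbitrary value chosen to leave room for the legend.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(7, 4))
ax.plot([1, 2, 3], [2, 4, 1], label='Line 1')
ax.plot([1, 2, 3], [5, 1, 3], label='Line 2')

# Anchor the legend's upper-left corner just outside the axes' top-right corner
legend = ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0)

# Shrink the axes to 75% of the figure width so the legend fits on the right
fig.subplots_adjust(right=0.75)

print([t.get_text() for t in legend.get_texts()])  # ['Line 1', 'Line 2']
```

When saving such a figure, bbox_inches='tight' in fig.savefig() also enlarges the saved area so an outside legend is not cropped.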

Ticks, Grid Lines, and Annotations

  • Ticks: Rotate long x-tick labels with plt.xticks(rotation=45), or with ax.tick_params(axis='x', labelrotation=45) in the object-oriented style.
  • Grid lines: Use ax.grid(axis='x') or ax.grid(axis='y') to show grid lines on one axis only.
  • Annotations: Add text or arrows with ax.text(x, y, 'text') or ax.annotate().

Example: Annotating a plot

fig, ax = plt.subplots(figsize=(8, 5))

x = np.linspace(0, 10, 100)
y = np.sin(x)

ax.plot(x, y, color='blue')
ax.set_title('Annotated Sine Wave')

# Highlight the peak (x=π/2, y=1)
peak_x = np.pi/2
peak_y = 1
ax.plot(peak_x, peak_y, 'ro')  # Red circle marker
ax.set_ylim(-1.5, 1.5)  # Leave headroom so the annotation text stays inside the axes
ax.annotate(
    'Peak (π/2, 1)', 
    xy=(peak_x, peak_y),  # Point to annotate
    xytext=(peak_x + 1, peak_y + 0.2),  # Text position
    arrowprops=dict(facecolor='black', shrink=0.05)  # Arrow style
)

# Rotate x-ticks
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
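
Two related techniques are placing ticks at meaningful positions with custom labels and positioning text relative to the axes instead of the data. In this sketch, transform=ax.transAxes makes (0, 0) the lower-left and (1, 1) the upper-right corner of the plot area, so the note stays in the corner whatever the data limits are; ax.tick_params(labelrotation=45) is the object-oriented counterpart of plt.xticks(rotation=45).

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(x, np.sin(x), color='blue')

# Ticks at multiples of pi, with readable labels
ticks = [0, np.pi, 2*np.pi, 3*np.pi]
ax.set_xticks(ticks)
ax.set_xticklabels(['0', 'π', '2π', '3π'])
ax.tick_params(axis='x', labelrotation=45)

# Text positioned in axes coordinates (fractions of the plot area)
ax.text(0.02, 0.95, 'y = sin(x)', transform=ax.transAxes, va='top')

ax.grid(axis='x', linestyle=':')
```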

Saving Plots

Save plots to files using fig.savefig(). The output format is inferred from the file extension; common choices are PNG (raster) and PDF, SVG, and EPS (vector formats).

Example: Saving a plot

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 1, 3])
ax.set_title('Plot to Save')

# Save as PNG (high resolution)
fig.savefig('my_plot.png', dpi=300, bbox_inches='tight')  # dpi=300 for print

# Save as PDF (vector format, scalable)
fig.savefig('my_plot.pdf', bbox_inches='tight')

# Save as SVG (vector format for web)
fig.savefig('my_plot.svg', bbox_inches='tight')

  • dpi: Dots per inch (resolution; 300 is standard for print).
  • bbox_inches='tight': Ensures labels/legends aren’t cut off.
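
fig.savefig() also accepts a file-like object, which is convenient when a plot should be sent over the network or embedded elsewhere without touching the disk. Because there is no file extension to infer the format from, format='png' is passed explicitly in this sketch. Closing the figure with plt.close() afterwards frees its memory, which matters when generating many plots in a loop.

```python
import io

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [4, 1, 3])
ax.set_title('Plot to Save')

# Write the PNG into an in-memory buffer instead of a file
buffer = io.BytesIO()
fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
png_bytes = buffer.getvalue()

print(png_bytes[:8] == b'\x89PNG\r\n\x1a\n')  # True: the standard PNG signature

# Release the figure once it is no longer needed
plt.close(fig)
```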

Advanced Tips and Best Practices

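  1. Use Styles: Matplotlib ships with built-in style sheets (e.g., 'ggplot' and the seaborn-inspired 'seaborn-v0_8-*' family) that restyle plots with a single line. Enable one with plt.style.use('seaborn-v0_8-whitegrid'); plt.style.available lists every installed style name. (Since Matplotlib 3.6, the old 'seaborn-*' style names carry the 'seaborn-v0_8-' prefix.)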

    plt.style.use('seaborn-v0_8-darkgrid')  # Apply style
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [4, 1, 3])
    plt.show()

  2. Pandas Integration: Pandas DataFrames have a .plot() method that wraps Matplotlib for quick plotting:

    import matplotlib.pyplot as plt
    import pandas as pd

    df = pd.DataFrame({'A': [1, 3, 2], 'B': [4, 2, 5]})
    ax = df.plot(kind='bar', figsize=(8, 4))  # Uses Matplotlib under the hood and returns an Axes
    plt.show()

  3. 3D Plots: Use the mplot3d toolkit for 3D visualizations:

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(projection='3d')  # 3D Axes; no explicit mplot3d import needed on Matplotlib >= 3.2
    x = np.linspace(-5, 5, 100)
    y = np.linspace(-5, 5, 100)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))  # Radial "ripple": sine of the distance from the origin
    ax.plot_surface(X, Y, Z, cmap='viridis')  # cmap: color map
    plt.show()

  4. Best Practices:

    • Keep plots simple: Avoid clutter (e.g., too many colors, grid lines).
    • Use colorblind-friendly palettes (e.g., 'viridis', 'plasma' colormaps).
    • Label all axes and include units (e.g., “Time (seconds)”).
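
Note that plt.style.use() changes settings globally for the rest of the session. To style a single figure only, plt.style.context() can be used as a context manager, as in this sketch; 'ggplot' is just one example of an available style. Figures created inside the block keep their styled look after it ends, while figures created later use the restored settings.

```python
import matplotlib.pyplot as plt

# Built-in and user-installed style names
print('ggplot' in plt.style.available)

before = plt.rcParams['axes.facecolor']

# The style applies only inside the with-block
with plt.style.context('ggplot'):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [4, 1, 3])
    inside = plt.rcParams['axes.facecolor']

# Global settings are restored once the block exits
after = plt.rcParams['axes.facecolor']
```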

Conclusion

Matplotlib is a versatile library for creating static, animated, and interactive visualizations in Python. By mastering its core concepts (Figure, Axes), common plot types, and customization techniques, you can generate clear, informative plots for analysis, presentations, or publications.

Start with simple plots, experiment with customization, and gradually explore advanced features like 3D plots or integration with Pandas. For even more polished visuals, check out Seaborn—a library built on Matplotlib that simplifies statistical plotting.

References