Chapter 7 - Mock and Test Doubles

Haiyue
26min

Chapter 7: Mock and Test Doubles

Learning Objectives
  • Understand Mock concepts and use cases
  • Master the use of unittest.mock
  • Learn dependency injection and test isolation
  • Master the use of pytest-mock plugin

Knowledge Points

Mock Concepts and Uses

A mock object is a kind of test double — a stand-in for a real dependency — used for:

  • Dependency isolation: Isolate code under test from external dependencies
  • Behavior control: Precisely control dependency behavior and return values
  • Interaction verification: Verify method call counts and parameters
  • Improve speed: Avoid slow real network and database calls

Test Double Types

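| Type  | Description                                     | Use Case                                 |
|-------|-------------------------------------------------|------------------------------------------|
| Dummy | Placeholder object that is never actually used  | Satisfy interface/parameter requirements |
| Fake  | Simplified working implementation               | In-memory database                       |
| Stub  | Object with preset responses                    | Fixed return values                      |
| Spy   | Object that records call information            | Verify interactions                      |
| Mock  | Fully controllable test double                  | Complex interaction testing              |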

Mock Tool Comparison

# unittest.mock - built into Python's standard library
from unittest.mock import Mock, patch, MagicMock

# pytest itself ships pytest.MonkeyPatch (used through the `monkeypatch` fixture);
# the pytest-mock plugin (pip install pytest-mock) adds the `mocker` fixture.
import pytest

# responses - mocks HTTP calls made with the `requests` library (pip install responses)
import responses

Example Code

Basic Mock Usage

# test_mock_basic.py
import pytest
from unittest.mock import Mock, MagicMock, patch

# Code under test
class EmailService:
    """Thin wrapper around an injected SMTP client.

    The client is supplied by the caller (dependency injection), which is what
    lets tests substitute a Mock for a real SMTP connection.
    """

    def __init__(self, smtp_client):
        # Expected to expose send(to=..., subject=..., body=...) returning an
        # object with a `message_id` attribute.
        self.smtp_client = smtp_client

    def send_email(self, to_email, subject, body):
        """Send one message and report the outcome as a dict.

        Returns ``{"status": "success", "message_id": ...}`` on success, or
        ``{"status": "error", "error": <exception text>}`` when the client call
        (or reading ``message_id``) raises an Exception; such errors are not
        re-raised.
        """
        try:
            sent = self.smtp_client.send(to=to_email, subject=subject, body=body)
            outcome = {"status": "success", "message_id": sent.message_id}
        except Exception as exc:
            outcome = {"status": "error", "error": str(exc)}
        return outcome

    def send_bulk_emails(self, email_list):
        """Send every message in *email_list*, in order.

        Each item must be a mapping with "to", "subject" and "body" keys.
        Returns one send_email result dict per item, in input order.
        """
        return [
            self.send_email(item["to"], item["subject"], item["body"])
            for item in email_list
        ]

class TestEmailService:
    """EmailService tests with the SMTP client replaced by a Mock."""

    def test_send_email_success(self):
        """A successful send reports status "success" and the message id."""
        # Stub: smtp_client.send(...) returns an object carrying message_id
        mock_smtp_client = Mock()
        mock_result = Mock()
        mock_result.message_id = "msg_123456"
        mock_smtp_client.send.return_value = mock_result

        email_service = EmailService(mock_smtp_client)

        result = email_service.send_email(
            "test@example.com",
            "Test Subject",
            "Test Body"
        )

        # Compare the whole result so unexpected extra keys are caught too
        assert result == {"status": "success", "message_id": "msg_123456"}

        # The service must forward its arguments to the client as keywords
        mock_smtp_client.send.assert_called_once_with(
            to="test@example.com",
            subject="Test Subject",
            body="Test Body"
        )

    def test_send_email_failure(self):
        """A client exception is turned into an error result, not raised."""
        mock_smtp_client = Mock()
        # An exception instance as side_effect makes every call raise it
        mock_smtp_client.send.side_effect = ConnectionError("SMTP server unavailable")

        email_service = EmailService(mock_smtp_client)

        result = email_service.send_email(
            "test@example.com",
            "Test Subject",
            "Test Body"
        )

        assert result == {"status": "error", "error": "SMTP server unavailable"}
        mock_smtp_client.send.assert_called_once()

    def test_bulk_email_sending(self):
        """Each message in the list is sent exactly once, with its own arguments."""
        mock_smtp_client = Mock()
        mock_result = Mock()
        mock_result.message_id = "msg_123"
        mock_smtp_client.send.return_value = mock_result

        email_service = EmailService(mock_smtp_client)

        email_list = [
            {"to": "user1@example.com", "subject": "Subject 1", "body": "Body 1"},
            {"to": "user2@example.com", "subject": "Subject 2", "body": "Body 2"}
        ]

        results = email_service.send_bulk_emails(email_list)

        assert results == [
            {"status": "success", "message_id": "msg_123"},
            {"status": "success", "message_id": "msg_123"},
        ]

        # call_count alone would not catch swapped or repeated arguments, so
        # compare the keyword arguments of every call, in order. The input
        # dicts use the same keys (to/subject/body) that send() receives.
        assert mock_smtp_client.send.call_count == 2
        sent_kwargs = [c.kwargs for c in mock_smtp_client.send.call_args_list]
        assert sent_kwargs == email_list

Using patch Decorator

# test_patch_decorator.py
from unittest.mock import Mock, patch

import pytest  # needed for pytest.raises in TestAPIClient
import requests

# Code under test
class APIClient:
    """Minimal JSON-over-HTTP client for a users endpoint, built on requests."""

    def __init__(self, base_url):
        # Paths are appended as "/users/...", so base_url is expected to have
        # no trailing slash, e.g. "https://api.example.com".
        self.base_url = base_url

    def get_user(self, user_id):
        """Fetch ``/users/<user_id>`` and return the decoded JSON body.

        Raises:
            requests.HTTPError: via raise_for_status on an error response.
        """
        response = requests.get(f"{self.base_url}/users/{user_id}")
        response.raise_for_status()
        return response.json()

    def create_user(self, user_data):
        """POST *user_data* as JSON to ``/users`` and return the decoded reply.

        Raises:
            requests.HTTPError: via raise_for_status on an error response.
        """
        response = requests.post(
            f"{self.base_url}/users",
            json=user_data
        )
        response.raise_for_status()
        return response.json()

    def get_user_with_retry(self, user_id, max_retries=3, retry_delay=1):
        """Call get_user, retrying when a request fails.

        Args:
            user_id: identifier passed through to get_user.
            max_retries: total number of attempts; must be at least 1.
            retry_delay: seconds to sleep between failed attempts.

        Returns:
            The decoded JSON body from the first successful attempt.

        Raises:
            ValueError: if max_retries < 1 (otherwise no attempt would be made
                and None would be returned silently).
            requests.RequestException: the last error once every attempt failed.
        """
        import time

        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")

        for attempt in range(1, max_retries + 1):
            try:
                return self.get_user(user_id)
            except requests.RequestException:
                if attempt == max_retries:
                    raise
                # Looked up as time.sleep at call time, so patch('time.sleep') works
                time.sleep(retry_delay)

class TestAPIClient:
    """APIClient tests with requests.get / requests.post patched out."""

    @patch('requests.get')
    def test_get_user_success(self, mock_get):
        """get_user returns the decoded JSON of a successful response."""
        payload = {
            "id": 1,
            "name": "John Doe",
            "email": "john@example.com"
        }
        fake_response = Mock()
        fake_response.json.return_value = payload
        fake_response.raise_for_status.return_value = None
        mock_get.return_value = fake_response

        user = APIClient("https://api.example.com").get_user(1)

        observed = (user["id"], user["name"], user["email"])
        assert observed == (1, "John Doe", "john@example.com")

        # The URL is base_url + /users/<id>, with no other arguments
        mock_get.assert_called_once_with("https://api.example.com/users/1")

    @patch('requests.get')
    def test_get_user_not_found(self, mock_get):
        """An HTTP error status surfaces as requests.HTTPError."""
        failing_response = Mock()
        failing_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
        mock_get.return_value = failing_response

        api = APIClient("https://api.example.com")

        with pytest.raises(requests.HTTPError):
            api.get_user(999)

    @patch('requests.post')
    def test_create_user(self, mock_post):
        """create_user posts the payload as JSON and returns the reply."""
        created = Mock()
        created.json.return_value = {
            "id": 2,
            "name": "Jane Doe",
            "email": "jane@example.com"
        }
        created.raise_for_status.return_value = None
        mock_post.return_value = created

        api = APIClient("https://api.example.com")
        new_user = {"name": "Jane Doe", "email": "jane@example.com"}

        reply = api.create_user(new_user)

        assert (reply["id"], reply["name"]) == (2, "Jane Doe")

        # The payload must travel as the json= keyword, not as form data
        mock_post.assert_called_once_with(
            "https://api.example.com/users",
            json=new_user
        )

    @patch('time.sleep')
    @patch('requests.get')
    def test_get_user_with_retry(self, mock_get, mock_sleep):
        """Two transient failures followed by a success still yield the user."""
        # Decorators apply bottom-up, so requests.get arrives first
        ok_response = Mock()
        ok_response.json.return_value = {"id": 1, "name": "John"}
        ok_response.raise_for_status.return_value = None

        # Each call consumes the next item: exceptions are raised, others returned
        mock_get.side_effect = [
            requests.RequestException("Network error"),
            requests.RequestException("Timeout"),
            ok_response
        ]

        fetched = APIClient("https://api.example.com").get_user_with_retry(1)

        assert (fetched["id"], fetched["name"]) == (1, "John")

        # Three attempts, with a (patched, instant) sleep after each failure
        assert mock_get.call_count == 3
        assert mock_sleep.call_count == 2

pytest-mock Plugin Usage

# test_pytest_mock.py
# Requires installation: pip install pytest-mock

import os
from pathlib import Path

import pytest  # needed for pytest.raises in TestWithPytestMock
import requests

# Code under test
class FileManager:
    """Small helper for reading config files and appending to a log file."""

    def read_config(self, config_path):
        """Return the full text of *config_path*.

        Raises:
            FileNotFoundError: when os.path.exists reports the path as missing.
        """
        if os.path.exists(config_path):
            with open(config_path, 'r') as handle:
                return handle.read()
        raise FileNotFoundError(f"Config file not found: {config_path}")

    def write_log(self, message):
        """Append one ``[<ISO timestamp>] <message>`` line to ./app.log."""
        # Imported at call time; the timestamp is read via datetime.datetime,
        # which is what lets tests patch 'datetime.datetime'.
        import datetime

        target = Path("app.log")
        with open(target, 'a') as log_file:
            stamp = datetime.datetime.now().isoformat()
            log_file.write(f"[{stamp}] {message}\n")

class WeatherService:
    """Client for a current-conditions weather HTTP endpoint."""

    def __init__(self, api_key):
        self.api_key = api_key
        self.base_url = "https://api.weather.com"

    def get_current_weather(self, city):
        """Fetch current conditions for *city* and flatten the reply.

        Returns a dict with "city", "temperature" (the reply's ``temp_c``
        field) and "description" keys.

        Raises:
            requests.HTTPError: via raise_for_status on an error response.
            KeyError: if the JSON body lacks one of the nested fields read here.
        """
        # The city and API key are sent as query-string parameters
        response = requests.get(
            f"{self.base_url}/current",
            params={"city": city, "key": self.api_key},
        )
        response.raise_for_status()

        body = response.json()
        summary = {"city": body["location"]["city"]}
        summary["temperature"] = body["current"]["temp_c"]
        summary["description"] = body["current"]["condition"]["text"]
        return summary

class TestWithPytestMock:
    """FileManager and WeatherService tests using the pytest-mock ``mocker`` fixture.

    Patches made through ``mocker`` are undone automatically at the end of
    each test, so no explicit cleanup is needed.
    """

    def test_file_manager_read_config(self, mocker):
        """An existing config file is read and returned verbatim."""
        # Pretend the path exists and serve canned content from open()
        mocker.patch('os.path.exists', return_value=True)
        fake_open = mocker.mock_open(read_data="config_content")
        mocker.patch('builtins.open', fake_open)

        content = FileManager().read_config("/path/to/config.ini")

        assert content == "config_content"
        fake_open.assert_called_once_with("/path/to/config.ini", 'r')

    def test_file_manager_config_not_found(self, mocker):
        """A missing config path raises FileNotFoundError."""
        mocker.patch('os.path.exists', return_value=False)
        manager = FileManager()

        with pytest.raises(FileNotFoundError, match="Config file not found"):
            manager.read_config("/nonexistent/config.ini")

    def test_file_manager_write_log(self, mocker):
        """write_log appends a timestamped line to app.log."""
        # Freeze the timestamp: datetime.datetime.now().isoformat() -> fixed string
        fake_datetime = mocker.patch('datetime.datetime')
        fake_datetime.now.return_value.isoformat.return_value = "2023-01-01T12:00:00"

        fake_open = mocker.mock_open()
        mocker.patch('builtins.open', fake_open)

        FileManager().write_log("Test message")

        fake_open.assert_called_once_with(Path("app.log"), 'a')
        # mock_open returns the same file-handle mock on every call
        log_handle = fake_open()
        log_handle.write.assert_called_once_with("[2023-01-01T12:00:00] Test message\n")

    def test_weather_service_success(self, mocker):
        """A well-formed API reply is flattened into city/temperature/description."""
        api_reply = {
            "location": {"city": "Beijing"},
            "current": {
                "temp_c": 25.0,
                "condition": {"text": "Sunny"}
            }
        }
        fake_response = mocker.Mock()
        fake_response.json.return_value = api_reply
        fake_response.raise_for_status.return_value = None

        mock_get = mocker.patch('requests.get', return_value=fake_response)

        result = WeatherService("test_api_key").get_current_weather("Beijing")

        expected = {"city": "Beijing", "temperature": 25.0, "description": "Sunny"}
        for field, value in expected.items():
            assert result[field] == value

        # The city and API key must be sent as query parameters
        mock_get.assert_called_once_with(
            "https://api.weather.com/current",
            params={"city": "Beijing", "key": "test_api_key"}
        )

    def test_weather_service_api_error(self, mocker):
        """An HTTP error from the API propagates to the caller."""
        failing = mocker.Mock()
        failing.raise_for_status.side_effect = requests.HTTPError("API Error")
        mocker.patch('requests.get', return_value=failing)

        service = WeatherService("test_api_key")

        with pytest.raises(requests.HTTPError):
            service.get_current_weather("Invalid City")

Advanced Mock Techniques

# test_advanced_mock.py
import asyncio
from unittest.mock import MagicMock, Mock, PropertyMock, call, patch  # patch: used by patch.object

import pytest

# Code under test
class DatabaseConnection:
    """Simulated database connection used as the dependency in these examples.

    No real I/O happens: it only tracks a connected flag, and queries echo
    their inputs back with an empty row list.
    """

    def __init__(self, connection_string):
        self.connection_string = connection_string
        self._connected = False

    @property
    def connected(self):
        """True after connect() and before close()."""
        return self._connected

    def connect(self):
        """Mark the connection open and return True."""
        self._connected = True
        return True

    def execute_query(self, query, params=None):
        """Return a fake result echoing *query* and *params* with no rows.

        Raises:
            RuntimeError: when called while not connected.
        """
        if self.connected:
            return {"query": query, "params": params, "rows": []}
        raise RuntimeError("Database not connected")

    def close(self):
        """Mark the connection closed."""
        self._connected = False

class UserRepository:
    """Data-access layer for the users table on top of an injected connection."""

    def __init__(self, db_connection):
        # Any object exposing `connected`, `connect()` and
        # `execute_query(query, params)` -- DatabaseConnection or a Mock.
        self.db = db_connection

    def _ensure_connected(self):
        """Open the connection lazily if it is not already open."""
        if not self.db.connected:
            self.db.connect()

    def find_user_by_id(self, user_id):
        """Return the first row matching *user_id*, or None if there is none.

        The id is passed as a bound query parameter, not interpolated into SQL.
        """
        self._ensure_connected()

        result = self.db.execute_query(
            "SELECT * FROM users WHERE id = ?",
            (user_id,)
        )

        if result["rows"]:
            return result["rows"][0]
        return None

    def create_user(self, user_data):
        """Insert a user and return *user_data* with an id attached.

        *user_data* must contain "name" and "email". The returned id is a
        placeholder 1 (overridden if *user_data* carries its own "id"); the
        insert's result is not inspected.
        """
        self._ensure_connected()

        query = "INSERT INTO users (name, email) VALUES (?, ?)"
        params = (user_data["name"], user_data["email"])

        self.db.execute_query(query, params)
        return {"id": 1, **user_data}  # Simulated id; no insert id is read back

class AsyncAPIClient:
    """Asynchronous JSON fetcher built on aiohttp."""

    async def fetch_data(self, url):
        """GET *url* and return the decoded JSON body.

        aiohttp is imported inside the coroutine, so it is only required when
        fetch_data actually runs.
        """
        import aiohttp

        async with aiohttp.ClientSession() as http_session:
            async with http_session.get(url) as reply:
                payload = await reply.json()
                return payload

class TestAdvancedMock:
    """Advanced Mock techniques: properties, side_effect functions, call
    tracking, context managers, async code, partial patching, chained calls,
    spec enforcement and spies."""

    def test_mock_property(self, mocker):
        """Replace a read-only property with PropertyMock."""
        # spec=DatabaseConnection limits the mock to the real class's attributes
        mock_db = Mock(spec=DatabaseConnection)
        # A Mock's default return value is another Mock, which is not
        # subscriptable; find_user_by_id reads result["rows"], so return a dict.
        mock_db.execute_query.return_value = {"rows": []}

        # A PropertyMock cannot be attached to the mock instance directly;
        # it has to be set on the mock's type.
        type(mock_db).connected = PropertyMock(return_value=True)

        repo = UserRepository(mock_db)

        # connected property should return True
        assert mock_db.connected is True

        # Should not call connect, because already connected
        assert repo.find_user_by_id(1) is None
        mock_db.connect.assert_not_called()

    def test_mock_side_effect_function(self):
        """Use a function as side_effect to compute responses from the arguments."""
        def mock_execute_query(query, params=None):
            """Fake query engine: answers based on the SQL text and params."""
            if "SELECT" in query and params and params[0] == 1:
                return {"rows": [{"id": 1, "name": "John", "email": "john@example.com"}]}
            elif "INSERT" in query:
                return {"rows": [], "last_insert_id": 1}
            else:
                return {"rows": []}

        mock_db = Mock()
        mock_db.connected = True
        mock_db.execute_query.side_effect = mock_execute_query

        repo = UserRepository(mock_db)

        # Test finding user
        user = repo.find_user_by_id(1)
        assert user["name"] == "John"

        # Test creating user
        new_user = repo.create_user({"name": "Jane", "email": "jane@example.com"})
        assert new_user["name"] == "Jane"

    def test_mock_call_tracking(self):
        """Track how often, and with what arguments, a mock was called."""
        mock_db = Mock()
        # Stays False because the mocked connect() does not change it,
        # so every repository call triggers connect()
        mock_db.connected = False
        mock_db.execute_query.return_value = {"rows": []}

        repo = UserRepository(mock_db)

        # Execute multiple operations
        repo.find_user_by_id(1)
        repo.find_user_by_id(2)
        repo.create_user({"name": "Test", "email": "test@example.com"})

        # Verify call count
        assert mock_db.connect.call_count == 3
        assert mock_db.execute_query.call_count == 3

        # Verify call order
        expected_calls = [
            call("SELECT * FROM users WHERE id = ?", (1,)),
            call("SELECT * FROM users WHERE id = ?", (2,)),
            call("INSERT INTO users (name, email) VALUES (?, ?)", ("Test", "test@example.com"))
        ]
        mock_db.execute_query.assert_has_calls(expected_calls)

    def test_mock_context_manager(self, mocker):
        """MagicMock supports the context-manager protocol out of the box."""
        # mock database connection as context manager
        mock_db = mocker.MagicMock()
        mock_db.__enter__.return_value = mock_db
        mock_db.execute_query.return_value = {"rows": [{"id": 1, "name": "Test"}]}

        # Use context manager
        with mock_db as db:
            result = db.execute_query("SELECT * FROM users")

        # Verify context manager calls
        mock_db.__enter__.assert_called_once()
        mock_db.__exit__.assert_called_once()
        assert result["rows"][0]["name"] == "Test"

    # Requires the pytest-asyncio plugin (pip install pytest-asyncio); aiohttp
    # must also be installed so that 'aiohttp.ClientSession' can be patched.
    @pytest.mark.asyncio
    async def test_async_mock(self, mocker):
        """Mock an aiohttp session used through nested ``async with`` blocks."""
        # ClientSession() and session.get(url) are plain (sync) calls that
        # return async context managers, so model them with MagicMock, whose
        # __aenter__/__aexit__ are AsyncMocks. With an AsyncMock session,
        # session.get(url) would return a coroutine, which `async with` rejects.
        mock_session = mocker.MagicMock()
        mock_session.__aenter__.return_value = mock_session

        # response.json() is awaited, so the response itself is an AsyncMock
        mock_response = mocker.AsyncMock()
        mock_response.json.return_value = {"data": "test_data"}
        mock_session.get.return_value.__aenter__.return_value = mock_response

        # mock aiohttp.ClientSession
        mocker.patch('aiohttp.ClientSession', return_value=mock_session)

        client = AsyncAPIClient()
        result = await client.fetch_data("https://api.example.com/data")

        assert result["data"] == "test_data"
        mock_session.get.assert_called_once_with("https://api.example.com/data")

    def test_mock_partial_methods(self):
        """Patch a single method on a real object and keep the rest real."""
        # Create real object, but mock specific method
        db = DatabaseConnection("test://connection")

        # mock execute_query method, but keep other methods
        with patch.object(db, 'execute_query') as mock_execute:
            mock_execute.return_value = {"rows": [{"id": 1, "name": "Mocked User"}]}

            repo = UserRepository(db)

            # connect method is real
            assert not db.connected

            user = repo.find_user_by_id(1)

            # connect was really called
            assert db.connected

            # execute_query was mocked
            assert user["name"] == "Mocked User"
            mock_execute.assert_called_once()

    def test_mock_chained_calls(self):
        """Configure a chain of calls through nested return_value attributes."""
        mock_api = Mock()

        # Set up chained calls
        mock_api.users.get.return_value.json.return_value = {
            "id": 1,
            "name": "John"
        }

        # Simulate API call
        result = mock_api.users.get(1).json()

        assert result["name"] == "John"
        mock_api.users.get.assert_called_once_with(1)

    def test_mock_configuration(self):
        """spec makes a mock reject attributes the real class does not have."""
        # Use spec to ensure mock object has correct interface
        mock_db = Mock(spec=DatabaseConnection)

        # This will work, because DatabaseConnection has connect method
        mock_db.connect()

        # This will raise AttributeError, because DatabaseConnection doesn't have invalid_method
        with pytest.raises(AttributeError):
            mock_db.invalid_method()

    def test_spy_pattern(self, mocker):
        """mocker.spy records calls while still running the real method."""
        # Create real object
        db = DatabaseConnection("test://connection")

        # Use spy to monitor real method
        spy_connect = mocker.spy(db, 'connect')

        # Call real method
        db.connect()

        # Verify method was called
        spy_connect.assert_called_once()

        # Real functionality still works
        assert db.connected is True

Mock Best Practices

# test_mock_best_practices.py
import pytest
from unittest.mock import Mock, patch
from contextlib import contextmanager

class TestMockBestPractices:
    """Mock best practices examples"""

    def test_mock_return_value_vs_side_effect(self):
        """return_value vs side_effect usage"""
        mock_func = Mock()

        # Use return_value to return fixed value
        mock_func.return_value = "fixed_result"
        assert mock_func() == "fixed_result"

        # Use side_effect to return different values
        mock_func.side_effect = ["result1", "result2", "result3"]
        assert mock_func() == "result1"
        assert mock_func() == "result2"
        assert mock_func() == "result3"

        # side_effect can also be a function
        mock_func.side_effect = lambda x: f"processed_{x}"
        assert mock_func("input") == "processed_input"

    def test_mock_with_spec(self):
        """Using spec to restrict mock behavior"""
        # Mock without spec can call any attribute
        loose_mock = Mock()
        loose_mock.any_method()  # Won't error

        # Mock with spec can only call methods of specified class
        class RealClass:
            def real_method(self):
                pass

        strict_mock = Mock(spec=RealClass)
        strict_mock.real_method()  # OK

        with pytest.raises(AttributeError):
            strict_mock.fake_method()  # Will error

    @contextmanager
    def temporary_mock(self, target, **kwargs):
        """Temporary mock context manager"""
        with patch(target, **kwargs) as mock:
            yield mock

    def test_custom_mock_context(self):
        """Custom mock context"""
        with self.temporary_mock('time.sleep', return_value=None) as mock_sleep:
            import time
            time.sleep(1)
            mock_sleep.assert_called_once_with(1)

    def test_mock_cleanup(self):
        """Ensure mock cleanup"""
        original_function = len

        with patch('builtins.len', return_value=42):
            assert len([1, 2, 3]) == 42  # mock is active

        # After patch ends, original function is restored
        assert len([1, 2, 3]) == 3

    def test_avoid_over_mocking(self):
        """Avoid over-mocking"""
        # Bad practice: mock too many internal details
        # Good practice: only mock external dependencies

        class Calculator:
            def add(self, a, b):
                return a + b

            def multiply(self, a, b):
                return a * b

            def complex_calculation(self, x, y):
                sum_result = self.add(x, y)
                return self.multiply(sum_result, 2)

        calc = Calculator()

        # Don't mock internal methods add and multiply
        # Test complex_calculation result directly
        result = calc.complex_calculation(3, 4)
        assert result == 14  # (3 + 4) * 2

    def test_mock_assertions(self):
        """Mock assertion best practices"""
        mock_func = Mock()

        # Call mock
        mock_func("arg1", "arg2", key="value")

        # Detailed assertion
        mock_func.assert_called_once_with("arg1", "arg2", key="value")

        # Check call count
        assert mock_func.call_count == 1

        # Check if called
        assert mock_func.called

        # Reset mock
        mock_func.reset_mock()
        assert not mock_func.called
Mock Best Practices
  1. Clear boundaries: Only mock external dependencies, not code under test
  2. Use spec: Use spec parameter to ensure mock object interface is correct
  3. Verify interactions: Not only verify return values, but also verify method calls
  4. Avoid over-mocking: Too many mocks make tests fragile
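  5. Clean up mocks: Ensure mocks are undone after each test — prefer patch as a decorator or context manager, or the mocker fixture, all of which restore the original automatically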
Common Pitfalls
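  1. Mocking the wrong object: Patch a name where it is looked up, not where it is defined — e.g. if app.py does `from requests import get`, patch `app.get`, not `requests.get`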
  2. Over-mocking: Don’t mock internal methods of code under test
  3. Forgetting verification: After creating mock, verify its interactions
  4. State pollution: Ensure mocks don’t affect other tests

Mock is a powerful tool for isolating tests and controlling dependencies. Correct use of Mock can help write more reliable and faster tests.