Chapter 7 - Mock and Test Doubles
Haiyue
26min
Chapter 7: Mock and Test Doubles
Learning Objectives
- Understand Mock concepts and use cases
- Master the use of unittest.mock
- Learn dependency injection and test isolation
- Master the use of pytest-mock plugin
Knowledge Points
Mock Concepts and Uses
A mock (mock object) is a kind of test double: an object that stands in for a real dependency during a test. Mocks are used for:
- Dependency isolation: Isolate code under test from external dependencies
- Behavior control: Precisely control dependency behavior and return values
- Interaction verification: Verify method call counts and parameters
- Improve speed: Avoid slow real network and database calls
Test Double Types
| Type | Description | Use Case |
|---|---|---|
| Dummy | Placeholder object that is passed around but never actually used | Satisfy interface/parameter requirements |
| Fake | Simplified working implementation | In-memory database |
| Stub | Object with preset responses | Fixed return values |
| Spy | Object that records call information | Verify interactions |
| Mock | Fully controllable double, pre-programmed with expectations about the calls it should receive | Complex interaction testing |
Mock Tool Comparison
# unittest.mock - Built into Python
from unittest.mock import Mock, patch, MagicMock
# pytest-mock - pytest integration: no import needed, request the `mocker`
# fixture in a test (pytest's built-in `monkeypatch` fixture is a pytest.MonkeyPatch)
# responses - HTTP request mock
import responses
Example Code
Basic Mock Usage
# test_mock_basic.py
import pytest
from unittest.mock import Mock, MagicMock, patch
# Code under test
class EmailService:
    """Send email through an injected SMTP client.

    The client is passed in by the caller (dependency injection), which is what
    lets tests substitute a Mock for a real SMTP connection.
    """

    def __init__(self, smtp_client):
        # Any object with send(to=..., subject=..., body=...) whose return value
        # exposes a ``message_id`` attribute.
        self.smtp_client = smtp_client

    def send_email(self, to_email, subject, body):
        """Send a single email.

        Returns:
            ``{"status": "success", "message_id": ...}`` when the client call
            succeeds, otherwise ``{"status": "error", "error": <message>}``.
            Any exception raised by the client (or while reading
            ``message_id``) is turned into the error dict instead of propagating.
        """
        try:
            sent = self.smtp_client.send(to=to_email, subject=subject, body=body)
            outcome = {"status": "success", "message_id": sent.message_id}
        except Exception as exc:
            outcome = {"status": "error", "error": str(exc)}
        return outcome

    def send_bulk_emails(self, email_list):
        """Send every message in ``email_list`` and return the per-message results.

        Each item must be a mapping with "to", "subject" and "body" keys. A failed
        send does not stop the batch; it shows up as an error dict. Results keep
        the input order.
        """
        return [
            self.send_email(item["to"], item["subject"], item["body"])
            for item in email_list
        ]
class TestEmailService:
    """EmailService tests driven by hand-configured Mock SMTP clients."""

    @staticmethod
    def _client_returning(message_id):
        """Build a Mock SMTP client whose send() returns an object carrying ``message_id``."""
        client = Mock()
        client.send.return_value = Mock(message_id=message_id)
        return client

    def test_send_email_success(self):
        """A successful send reports the client's message id and forwards every field."""
        smtp = self._client_returning("msg_123456")
        service = EmailService(smtp)

        outcome = service.send_email("test@example.com", "Test Subject", "Test Body")

        # Check the returned value...
        assert outcome["status"] == "success"
        assert outcome["message_id"] == "msg_123456"
        # ...and the interaction with the dependency.
        smtp.send.assert_called_once_with(
            to="test@example.com",
            subject="Test Subject",
            body="Test Body",
        )

    def test_send_email_failure(self):
        """An exception raised by the client becomes an error result."""
        smtp = Mock()
        # side_effect set to an exception instance makes the call raise it.
        smtp.send.side_effect = ConnectionError("SMTP server unavailable")

        outcome = EmailService(smtp).send_email(
            "test@example.com", "Test Subject", "Test Body"
        )

        assert outcome["status"] == "error"
        assert "SMTP server unavailable" in outcome["error"]

    def test_bulk_email_sending(self):
        """Every message in the batch is sent and reported as successful."""
        smtp = self._client_returning("msg_123")
        batch = [
            {"to": f"user{n}@example.com", "subject": f"Subject {n}", "body": f"Body {n}"}
            for n in (1, 2)
        ]

        outcomes = EmailService(smtp).send_bulk_emails(batch)

        assert len(outcomes) == 2
        assert all(item["status"] == "success" for item in outcomes)
        assert smtp.send.call_count == 2
Using patch Decorator
# test_patch_decorator.py
import requests
from unittest.mock import patch, Mock
# Code under test
class APIClient:
    """Minimal REST client for the ``/users`` resource, built on ``requests``."""

    def __init__(self, base_url):
        # Prefix for every request URL; paths are appended as "/users...",
        # so pass it without a trailing slash.
        self.base_url = base_url

    def get_user(self, user_id):
        """GET ``/users/<user_id>`` and return the decoded JSON body.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on an error status.
        """
        response = requests.get(f"{self.base_url}/users/{user_id}")
        response.raise_for_status()
        return response.json()

    def create_user(self, user_data):
        """POST ``user_data`` as JSON to ``/users`` and return the decoded JSON body.

        Raises:
            requests.HTTPError: via ``raise_for_status`` on an error status.
        """
        response = requests.post(
            f"{self.base_url}/users",
            json=user_data,
        )
        response.raise_for_status()
        return response.json()

    def get_user_with_retry(self, user_id, max_retries=3):
        """Fetch a user, retrying when ``requests.RequestException`` is raised.

        Makes at most ``max_retries`` attempts and sleeps one second between
        consecutive attempts. The last failure is re-raised. ``requests.HTTPError``
        (e.g. a 404) is a RequestException subclass, so error statuses are
        retried too.

        Raises:
            ValueError: if ``max_retries`` is less than 1. Without this check
                the loop would never run and the method would return None.
            requests.RequestException: when every attempt fails.
        """
        # Imported here so tests can patch 'time.sleep' on the module.
        import time

        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        for attempt in range(max_retries):
            try:
                return self.get_user(user_id)
            except requests.RequestException:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
class TestAPIClient:
    """APIClient tests that replace ``requests`` functions via unittest.mock.patch."""

    @patch('requests.get')
    def test_get_user_success(self, mock_get):
        """get_user returns the decoded body and requests the expected URL."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "id": 1,
            "name": "John Doe",
            "email": "john@example.com",
        }
        mock_response.raise_for_status.return_value = None
        mock_get.return_value = mock_response

        client = APIClient("https://api.example.com")
        user = client.get_user(1)

        assert user["id"] == 1
        assert user["name"] == "John Doe"
        assert user["email"] == "john@example.com"
        mock_get.assert_called_once_with("https://api.example.com/users/1")

    @patch('requests.get')
    def test_get_user_not_found(self, mock_get):
        """An HTTP error raised by raise_for_status propagates to the caller."""
        # This snippet's header imports only requests and unittest.mock, so
        # bring pytest into scope here for pytest.raises.
        import pytest

        mock_response = Mock()
        mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
        mock_get.return_value = mock_response

        client = APIClient("https://api.example.com")
        with pytest.raises(requests.HTTPError):
            client.get_user(999)

    @patch('requests.post')
    def test_create_user(self, mock_post):
        """create_user POSTs the payload as JSON and returns the decoded body."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "id": 2,
            "name": "Jane Doe",
            "email": "jane@example.com",
        }
        mock_response.raise_for_status.return_value = None
        mock_post.return_value = mock_response

        client = APIClient("https://api.example.com")
        user_data = {"name": "Jane Doe", "email": "jane@example.com"}
        result = client.create_user(user_data)

        assert result["id"] == 2
        assert result["name"] == "Jane Doe"
        mock_post.assert_called_once_with(
            "https://api.example.com/users",
            json=user_data,
        )

    # Stacked patch decorators are applied bottom-up, so the mock for the
    # decorator closest to the function (requests.get) is the first argument.
    @patch('time.sleep')
    @patch('requests.get')
    def test_get_user_with_retry(self, mock_get, mock_sleep):
        """Two failures followed by a success: three attempts, two sleeps."""
        mock_response_success = Mock()
        mock_response_success.json.return_value = {"id": 1, "name": "John"}
        mock_response_success.raise_for_status.return_value = None
        # An iterable side_effect yields one item per call; exception items are raised.
        mock_get.side_effect = [
            requests.RequestException("Network error"),
            requests.RequestException("Timeout"),
            mock_response_success,
        ]

        client = APIClient("https://api.example.com")
        result = client.get_user_with_retry(1)

        assert result["id"] == 1
        assert result["name"] == "John"
        assert mock_get.call_count == 3
        assert mock_sleep.call_count == 2  # sleep after each of the first two failures
pytest-mock Plugin Usage
# test_pytest_mock.py
# Requires installation: pip install pytest-mock
import os
from pathlib import Path

import pytest
import requests
# Code under test
class FileManager:
    """Small file helper used to demonstrate patching os.path and open()."""

    def read_config(self, config_path):
        """Return the full text of ``config_path``.

        Raises:
            FileNotFoundError: if ``os.path.exists`` reports the path as missing.
        """
        if os.path.exists(config_path):
            with open(config_path, 'r') as handle:
                return handle.read()
        raise FileNotFoundError(f"Config file not found: {config_path}")

    def write_log(self, message):
        """Append ``message`` to ./app.log on its own line, prefixed by an ISO timestamp."""
        # Imported locally; datetime.datetime is looked up at call time, which
        # is what lets a test patch 'datetime.datetime'.
        import datetime

        with open(Path("app.log"), 'a') as handle:
            stamp = datetime.datetime.now().isoformat()
            handle.write(f"[{stamp}] {message}\n")
class WeatherService:
    """Client for the weather API's current-conditions endpoint."""

    def __init__(self, api_key):
        self.api_key = api_key
        self.base_url = "https://api.weather.com"

    def get_current_weather(self, city):
        """Fetch current conditions for ``city``.

        Returns a flat dict with "city", "temperature" (the payload's
        ``temp_c``) and "description" (the condition text).

        Raises:
            requests.HTTPError: via ``raise_for_status`` on an error status.
            KeyError: if the response body lacks an expected field.
        """
        response = requests.get(
            f"{self.base_url}/current",
            params={"city": city, "key": self.api_key},
        )
        response.raise_for_status()
        payload = response.json()
        return {
            "city": payload["location"]["city"],
            "temperature": payload["current"]["temp_c"],
            "description": payload["current"]["condition"]["text"],
        }
class TestWithPytestMock:
    """Tests using the pytest-mock ``mocker`` fixture (patches are undone after each test)."""

    @staticmethod
    def _weather_payload(city, temp_c, text):
        """Build a JSON body shaped like the weather API's current-conditions response."""
        return {
            "location": {"city": city},
            "current": {"temp_c": temp_c, "condition": {"text": text}},
        }

    def test_file_manager_read_config(self, mocker):
        """read_config returns the file text when the path exists."""
        mocker.patch('os.path.exists', return_value=True)
        fake_open = mocker.mock_open(read_data="config_content")
        mocker.patch('builtins.open', fake_open)

        content = FileManager().read_config("/path/to/config.ini")

        assert content == "config_content"
        fake_open.assert_called_once_with("/path/to/config.ini", 'r')

    def test_file_manager_config_not_found(self, mocker):
        """read_config raises FileNotFoundError when the path is reported missing."""
        mocker.patch('os.path.exists', return_value=False)
        manager = FileManager()

        with pytest.raises(FileNotFoundError, match="Config file not found"):
            manager.read_config("/nonexistent/config.ini")

    def test_file_manager_write_log(self, mocker):
        """write_log appends one timestamped line to app.log."""
        # Freeze the clock: write_log resolves datetime.datetime at call time.
        frozen = mocker.patch('datetime.datetime')
        frozen.now.return_value.isoformat.return_value = "2023-01-01T12:00:00"
        fake_open = mocker.mock_open()
        mocker.patch('builtins.open', fake_open)

        FileManager().write_log("Test message")

        fake_open.assert_called_once_with(Path("app.log"), 'a')
        fake_open().write.assert_called_once_with("[2023-01-01T12:00:00] Test message\n")

    def test_weather_service_success(self, mocker):
        """A 2xx response is flattened into city/temperature/description."""
        response = mocker.Mock()
        response.json.return_value = self._weather_payload("Beijing", 25.0, "Sunny")
        response.raise_for_status.return_value = None
        fake_get = mocker.patch('requests.get', return_value=response)

        weather = WeatherService("test_api_key").get_current_weather("Beijing")

        assert weather["city"] == "Beijing"
        assert weather["temperature"] == 25.0
        assert weather["description"] == "Sunny"
        fake_get.assert_called_once_with(
            "https://api.weather.com/current",
            params={"city": "Beijing", "key": "test_api_key"},
        )

    def test_weather_service_api_error(self, mocker):
        """An HTTP error from raise_for_status propagates to the caller."""
        response = mocker.Mock()
        response.raise_for_status.side_effect = requests.HTTPError("API Error")
        mocker.patch('requests.get', return_value=response)
        service = WeatherService("test_api_key")

        with pytest.raises(requests.HTTPError):
            service.get_current_weather("Invalid City")
Advanced Mock Techniques
# test_advanced_mock.py
import pytest
from unittest.mock import Mock, MagicMock, PropertyMock, call
import asyncio
# Code under test
class DatabaseConnection:
    """In-memory stand-in for a database connection; performs no real I/O."""

    def __init__(self, connection_string):
        self.connection_string = connection_string
        self._connected = False

    @property
    def connected(self):
        """True between a call to connect() and a call to close()."""
        return self._connected

    def connect(self):
        """Mark the connection as open and return True."""
        self._connected = True
        return True

    def execute_query(self, query, params=None):
        """Return a fake result that echoes ``query`` and ``params`` with no rows.

        Raises:
            RuntimeError: if called while not connected.
        """
        if self.connected:
            return {"query": query, "params": params, "rows": []}
        raise RuntimeError("Database not connected")

    def close(self):
        """Mark the connection as closed."""
        self._connected = False
class UserRepository:
    """Data-access object for the ``users`` table over an injected connection."""

    def __init__(self, db_connection):
        # Object exposing ``connected``, ``connect()`` and
        # ``execute_query(query, params)``, e.g. DatabaseConnection or a Mock.
        self.db = db_connection

    def _ensure_connected(self):
        """Open the connection lazily if it is not already open."""
        if not self.db.connected:
            self.db.connect()

    def find_user_by_id(self, user_id):
        """Return the first row matching ``user_id``, or None when there is none."""
        self._ensure_connected()
        result = self.db.execute_query(
            "SELECT * FROM users WHERE id = ?",
            (user_id,),
        )
        rows = result["rows"]
        return rows[0] if rows else None

    def create_user(self, user_data):
        """Insert a user and return it together with a simulated id.

        ``user_data`` must contain "name" and "email". The returned id is always
        1 (a placeholder; the query result is not inspected), and an "id" key
        already present in ``user_data`` overrides it.
        """
        self._ensure_connected()
        self.db.execute_query(
            "INSERT INTO users (name, email) VALUES (?, ?)",
            (user_data["name"], user_data["email"]),
        )
        return {"id": 1, **user_data}
class AsyncAPIClient:
    """Asynchronous JSON fetcher built on aiohttp."""

    async def fetch_data(self, url):
        """GET ``url`` in a fresh aiohttp session and return the decoded JSON body."""
        # Imported lazily so the module loads without aiohttp installed.
        import aiohttp

        async with aiohttp.ClientSession() as session, session.get(url) as response:
            payload = await response.json()
        return payload
class TestAdvancedMock:
    """Advanced Mock techniques: properties, side effects, call tracking, async, spies."""

    def test_mock_property(self):
        """PropertyMock lets a spec'd mock report a property value."""
        mock_db = Mock(spec=DatabaseConnection)
        # Properties are looked up on the type, so the PropertyMock is attached
        # to type(mock_db) rather than to the instance.
        type(mock_db).connected = PropertyMock(return_value=True)
        # find_user_by_id subscripts the query result (result["rows"]). A plain
        # Mock return value is not subscriptable, so give it a real dict.
        mock_db.execute_query.return_value = {"rows": []}
        repo = UserRepository(mock_db)

        assert mock_db.connected is True

        # Already "connected", so the repository must not call connect().
        assert repo.find_user_by_id(1) is None
        mock_db.connect.assert_not_called()

    def test_mock_side_effect_function(self):
        """A callable side_effect computes the return value from the call arguments."""
        def mock_execute_query(query, params=None):
            """Return canned rows depending on the query text and parameters."""
            if "SELECT" in query and params and params[0] == 1:
                return {"rows": [{"id": 1, "name": "John", "email": "john@example.com"}]}
            elif "INSERT" in query:
                return {"rows": [], "last_insert_id": 1}
            else:
                return {"rows": []}

        mock_db = Mock()
        mock_db.connected = True
        mock_db.execute_query.side_effect = mock_execute_query
        repo = UserRepository(mock_db)

        user = repo.find_user_by_id(1)
        assert user["name"] == "John"

        new_user = repo.create_user({"name": "Jane", "email": "jane@example.com"})
        assert new_user["name"] == "Jane"

    def test_mock_call_tracking(self):
        """call_count and assert_has_calls record every interaction in order."""
        mock_db = Mock()
        # connected stays False, so every repository call triggers connect().
        mock_db.connected = False
        mock_db.execute_query.return_value = {"rows": []}
        repo = UserRepository(mock_db)

        repo.find_user_by_id(1)
        repo.find_user_by_id(2)
        repo.create_user({"name": "Test", "email": "test@example.com"})

        assert mock_db.connect.call_count == 3
        assert mock_db.execute_query.call_count == 3

        expected_calls = [
            call("SELECT * FROM users WHERE id = ?", (1,)),
            call("SELECT * FROM users WHERE id = ?", (2,)),
            call("INSERT INTO users (name, email) VALUES (?, ?)", ("Test", "test@example.com")),
        ]
        mock_db.execute_query.assert_has_calls(expected_calls)

    def test_mock_context_manager(self, mocker):
        """MagicMock supports the with-statement protocol out of the box."""
        mock_db = mocker.MagicMock()
        mock_db.__enter__.return_value = mock_db
        mock_db.execute_query.return_value = {"rows": [{"id": 1, "name": "Test"}]}

        with mock_db as db:
            result = db.execute_query("SELECT * FROM users")

        mock_db.__enter__.assert_called_once()
        mock_db.__exit__.assert_called_once()
        assert result["rows"][0]["name"] == "Test"

    # Requires the pytest-asyncio plugin for the asyncio marker.
    @pytest.mark.asyncio
    async def test_async_mock(self, mocker):
        """Mock aiohttp's nested async context managers."""
        mock_session = mocker.AsyncMock()
        mock_response = mocker.AsyncMock()
        # Children of an AsyncMock are AsyncMocks, so response.json() is awaitable.
        mock_response.json.return_value = {"data": "test_data"}
        mock_session.__aenter__.return_value = mock_session
        # session.get(url) is called synchronously and its result is used with
        # `async with`. An AsyncMock child would return a coroutine instead, so
        # use a MagicMock, whose return value supports __aenter__/__aexit__.
        mock_session.get = mocker.MagicMock()
        mock_session.get.return_value.__aenter__.return_value = mock_response

        mocker.patch('aiohttp.ClientSession', return_value=mock_session)

        client = AsyncAPIClient()
        result = await client.fetch_data("https://api.example.com/data")

        assert result["data"] == "test_data"
        mock_session.get.assert_called_once_with("https://api.example.com/data")

    def test_mock_partial_methods(self, mocker):
        """Patch one method on a real object and keep the rest real."""
        db = DatabaseConnection("test://connection")
        # Only execute_query is replaced; the patch is undone after the test.
        mock_execute = mocker.patch.object(
            db,
            'execute_query',
            return_value={"rows": [{"id": 1, "name": "Mocked User"}]},
        )
        repo = UserRepository(db)

        # connect() is the real implementation.
        assert not db.connected
        user = repo.find_user_by_id(1)
        assert db.connected

        assert user["name"] == "Mocked User"
        mock_execute.assert_called_once_with("SELECT * FROM users WHERE id = ?", (1,))

    def test_mock_chained_calls(self):
        """Configure a chain of calls through nested return_value attributes."""
        mock_api = Mock()
        mock_api.users.get.return_value.json.return_value = {
            "id": 1,
            "name": "John",
        }

        result = mock_api.users.get(1).json()

        assert result["name"] == "John"
        mock_api.users.get.assert_called_once_with(1)

    def test_mock_configuration(self):
        """spec restricts a mock to the attributes of the real class."""
        mock_db = Mock(spec=DatabaseConnection)

        # Allowed: DatabaseConnection defines connect().
        mock_db.connect()

        # Rejected: DatabaseConnection has no invalid_method.
        with pytest.raises(AttributeError):
            mock_db.invalid_method()

    def test_spy_pattern(self, mocker):
        """mocker.spy records calls while still running the real method."""
        db = DatabaseConnection("test://connection")
        spy_connect = mocker.spy(db, 'connect')

        db.connect()

        spy_connect.assert_called_once()
        # The real implementation still ran.
        assert db.connected is True
Mock Best Practices
# test_mock_best_practices.py
import pytest
from unittest.mock import Mock, patch
from contextlib import contextmanager
class TestMockBestPractices:
    """Mock best-practice examples."""

    def test_mock_return_value_vs_side_effect(self):
        """return_value gives a fixed result; side_effect gives a sequence or computed result."""
        mock_func = Mock()

        mock_func.return_value = "fixed_result"
        assert mock_func() == "fixed_result"

        # An iterable side_effect yields one item per call (it overrides return_value).
        mock_func.side_effect = ["result1", "result2", "result3"]
        assert mock_func() == "result1"
        assert mock_func() == "result2"
        assert mock_func() == "result3"

        # A callable side_effect computes the result from the arguments.
        mock_func.side_effect = lambda x: f"processed_{x}"
        assert mock_func("input") == "processed_input"

    def test_mock_with_spec(self):
        """spec restricts a mock to the real class's attributes."""
        # Without spec, any attribute access works (and typos go unnoticed).
        loose_mock = Mock()
        loose_mock.any_method()

        class RealClass:
            def real_method(self):
                pass

        strict_mock = Mock(spec=RealClass)
        strict_mock.real_method()  # OK: defined on RealClass
        with pytest.raises(AttributeError):
            strict_mock.fake_method()  # not defined on RealClass

    @contextmanager
    def temporary_mock(self, target, **kwargs):
        """Patch ``target`` for the duration of a with-block and yield the mock."""
        with patch(target, **kwargs) as mock:
            yield mock

    def test_custom_mock_context(self):
        """A custom context manager wrapping patch()."""
        with self.temporary_mock('time.sleep', return_value=None) as mock_sleep:
            import time
            time.sleep(1)
            mock_sleep.assert_called_once_with(1)

    def test_mock_cleanup(self):
        """patch() restores the original object when the with-block exits."""
        # Patch a harmless target. Patching a builtin such as len would
        # affect everything running in the process while the patch is active.
        import os

        original_getcwd = os.getcwd
        with patch('os.getcwd', return_value="/mocked/dir"):
            assert os.getcwd() == "/mocked/dir"  # mock is active
        # After the patch ends, the original function is back in place.
        assert os.getcwd is original_getcwd

    def test_avoid_over_mocking(self):
        """Test observable results instead of mocking internal helpers."""
        # Bad practice: mock internal details such as add/multiply.
        # Good practice: mock only external dependencies.
        class Calculator:
            def add(self, a, b):
                return a + b

            def multiply(self, a, b):
                return a * b

            def complex_calculation(self, x, y):
                sum_result = self.add(x, y)
                return self.multiply(sum_result, 2)

        calc = Calculator()
        result = calc.complex_calculation(3, 4)
        assert result == 14  # (3 + 4) * 2

    def test_mock_assertions(self):
        """Mock assertion helpers and reset_mock()."""
        mock_func = Mock()
        mock_func("arg1", "arg2", key="value")

        mock_func.assert_called_once_with("arg1", "arg2", key="value")
        assert mock_func.call_count == 1
        assert mock_func.called

        # reset_mock() clears recorded calls.
        mock_func.reset_mock()
        assert not mock_func.called
Mock Best Practices
- Clear boundaries: Only mock external dependencies, not code under test
- Use spec: Use the `spec` parameter to ensure the mock object's interface matches the real object
- Verify interactions: Verify not only return values but also method calls
- Avoid over-mocking: Too many mocks make tests fragile
- Clean up mocks: Ensure mocks are properly cleaned up after tests
Common Pitfalls
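- Mocking the wrong object: Patch a name where it is looked up, not where it is defined. For example, if `app.py` does `from requests import get`, patch `app.get`, not `requests.get`.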
- Over-mocking: Don’t mock internal methods of code under test
- Forgetting verification: After creating mock, verify its interactions
- State pollution: Ensure mocks don’t affect other tests
Mock is a powerful tool for isolating tests and controlling dependencies. Correct use of Mock can help write more reliable and faster tests.