Chapter 07: Multi-Step Reasoning and Complex Task Decomposition

Haiyue
54min

Chapter 07: Multi-Step Reasoning and Complex Task Decomposition

Learning Objectives
  • Design multi-step reasoning DSPy programs
  • Implement task decomposition and subtask coordination
  • Build conditional execution and branching logic
  • Handle errors and recover from exceptions
  • Optimize performance of complex reasoning chains

Key Concepts

1. Multi-Step Reasoning Fundamentals

Multi-step reasoning is a key technique for solving complex problems: decomposing a problem into a sequence of simpler steps improves the accuracy of the final solution.

Reasoning Chain Design Principles

import json
import re
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Any, Optional, Union, Callable

import dspy

class ReasoningStep:
    """Base class for one step of a multi-step reasoning pipeline.

    Subclasses implement ``extract_inputs`` (pick what they need from the shared
    context) and ``process`` (compute the step's outputs). ``execute`` runs both,
    records timing/success/error state, and never raises.
    """

    def __init__(self, step_name: str, description: str = ""):
        self.step_name = step_name
        self.description = description
        self.inputs: Dict[str, Any] = {}    # inputs used by the most recent run
        self.outputs: Dict[str, Any] = {}   # outputs of the most recent run ({} on failure)
        self.execution_time = 0.0           # seconds taken by the most recent run
        self.success = False
        self.error_message = ""

    def execute(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Run the step against ``context``.

        Args:
            context: Shared pipeline state; passed to ``extract_inputs``.

        Returns:
            The dict produced by ``process``, or ``{}`` if any exception occurred.
            Exceptions are caught and recorded in ``success``/``error_message``.
        """
        start_time = time.perf_counter()  # monotonic clock for durations

        # Reset per-run state so a reused step never reports a stale error
        # message or stale inputs from a previous run.
        self.inputs = {}
        self.outputs = {}
        self.success = False
        self.error_message = ""

        try:
            self.inputs = self.extract_inputs(context)
            self.outputs = self.process(self.inputs)
            self.success = True

        except Exception as e:
            self.success = False
            # Exceptions without a message (e.g. bare NotImplementedError) would
            # otherwise leave an empty, useless error string.
            self.error_message = str(e) or type(e).__name__
            self.outputs = {}

        finally:
            self.execution_time = time.perf_counter() - start_time

        return self.outputs

    def extract_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Select this step's inputs from the shared context (subclasses must override)."""
        raise NotImplementedError

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute this step's outputs from its inputs (subclasses must override)."""
        raise NotImplementedError

    def get_summary(self) -> Dict[str, Any]:
        """Return a dict describing the most recent run of this step."""
        return {
            'step_name': self.step_name,
            'description': self.description,
            'success': self.success,
            'execution_time': self.execution_time,
            'error_message': self.error_message,
            'inputs': self.inputs,
            'outputs': self.outputs
        }

class MultiStepReasoner(dspy.Module):
    """Multi-step reasoner: analyze, plan, run registered steps, then synthesize.

    ``ReasoningStep`` objects registered via ``add_step`` run in insertion order
    against a shared context dict. Step *i* (1-based) stores its outputs under
    ``context['step_{i}_result']`` so later steps can read them.
    """

    def __init__(self):
        super().__init__()

        self.reasoning_steps = []   # ReasoningStep instances, executed in order
        self.context = {}           # shared state of the current forward() call
        self.execution_trace = []   # step summaries of the current forward() call

        # DSPy modules for different types of reasoning
        self.problem_analyzer = dspy.ChainOfThought(
            "problem -> analysis, sub_problems",
            instructions="Analyze the complexity of the problem and identify sub-problems that need to be solved."
        )

        self.step_planner = dspy.ChainOfThought(
            "problem, analysis -> reasoning, plan",
            instructions="Based on problem analysis, formulate a detailed step-by-step solution plan."
        )

        self.step_executor = dspy.ChainOfThought(
            "step_description, available_info -> reasoning, result",
            instructions="Execute a specific reasoning step, deriving results based on available information."
        )

        self.result_synthesizer = dspy.ChainOfThought(
            "original_problem, step_results -> reasoning, final_answer",
            instructions="Synthesize results from all steps to form a complete answer to the original problem."
        )

    def add_step(self, step: ReasoningStep):
        """Append a reasoning step; steps run in the order they were added."""
        self.reasoning_steps.append(step)

    def forward(self, problem: str):
        """Execute multi-step reasoning on ``problem``.

        Returns:
            ``dspy.Prediction`` with ``problem``, ``analysis``, ``plan``,
            ``step_results``, ``execution_trace`` (this run only),
            ``final_answer`` and ``reasoning``.
        """

        print(f"Starting multi-step reasoning: {problem}")

        # Start every run from a clean slate. Otherwise a second call would see
        # the previous problem's step results in the context and would return
        # an execution trace mixing both runs.
        self.context = {}
        self.execution_trace = []

        # 1. Problem analysis
        analysis_result = self.problem_analyzer(problem=problem)
        self.context['original_problem'] = problem
        self.context['analysis'] = analysis_result.analysis
        self.context['sub_problems'] = analysis_result.sub_problems

        print(f"Problem analysis: {analysis_result.analysis}")

        # 2. Create plan
        plan_result = self.step_planner(
            problem=problem,
            analysis=analysis_result.analysis
        )
        self.context['plan'] = plan_result.plan

        print(f"Execution plan: {plan_result.plan}")

        # 3. Execute reasoning steps
        step_results = []

        for i, step in enumerate(self.reasoning_steps):
            print(f"\nExecuting step {i+1}: {step.step_name}")

            # ReasoningStep.execute never raises; failures are reported via step.success
            step_output = step.execute(self.context)

            # Positional key read by later steps (e.g. 'step_1_result')
            self.context[f'step_{i+1}_result'] = step_output

            # Record execution trace
            step_summary = step.get_summary()
            self.execution_trace.append(step_summary)
            step_results.append(step_summary)

            if step.success:
                print(f"Step {i+1} completed")
            else:
                print(f"Step {i+1} failed: {step.error_message}")

                # Can choose whether to continue after error
                if self.should_continue_after_error(step, i):
                    print("Continuing with subsequent steps")
                    continue
                else:
                    print("Terminating execution")
                    break

        # 4. Result synthesis. default=str keeps arbitrary step outputs from
        # custom steps from crashing the prompt rendering; this text is only
        # shown to the LM.
        final_result = self.result_synthesizer(
            original_problem=problem,
            step_results=json.dumps(step_results, ensure_ascii=False, indent=2, default=str)
        )

        return dspy.Prediction(
            problem=problem,
            analysis=analysis_result.analysis,
            plan=plan_result.plan,
            step_results=step_results,
            execution_trace=self.execution_trace,
            final_answer=final_result.final_answer,
            reasoning=final_result.reasoning
        )

    def should_continue_after_error(self, failed_step: ReasoningStep, step_index: int) -> bool:
        """Decide whether to keep going after ``failed_step`` failed.

        The default stops at the first failure. Override this to treat some
        steps as optional.
        """
        return False

# Concrete reasoning step implementations
class FactExtractionStep(ReasoningStep):
    """First pipeline step: ask the LM for the key facts in the problem and its analysis."""

    def __init__(self):
        super().__init__("fact_extraction", "Extract key facts from input")
        fact_instructions = "Extract key facts and information from the given text."
        self.extractor = dspy.ChainOfThought(
            "text -> reasoning, facts",
            instructions=fact_instructions,
        )

    def extract_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        # The problem text and the analyzer's output are joined with one space.
        problem_text = context.get('original_problem', '')
        analysis_text = context.get('analysis', '')
        return dict(text=problem_text + ' ' + analysis_text)

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        prediction = self.extractor(text=inputs['text'])
        return dict(
            extracted_facts=prediction.facts,
            reasoning=prediction.reasoning,
        )

class LogicalInferenceStep(ReasoningStep):
    """Derive new conclusions from the facts stored by step 1 of the pipeline."""

    def __init__(self):
        super().__init__("logical_inference", "Perform logical reasoning based on known facts")
        inference_instructions = "Based on given facts and rules, perform logical reasoning to derive new conclusions."
        self.inferencer = dspy.ChainOfThought(
            "facts, rules -> reasoning, inferences",
            instructions=inference_instructions,
        )

    def extract_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        # Reads the positional key written by MultiStepReasoner for step 1;
        # a missing or failed step 1 yields empty facts.
        first_step = context.get('step_1_result', {})
        return dict(
            facts=first_step.get('extracted_facts', ''),
            rules="Use common sense and logical rules for reasoning",
        )

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        prediction = self.inferencer(facts=inputs['facts'], rules=inputs['rules'])
        return dict(
            inferences=prediction.inferences,
            reasoning=prediction.reasoning,
        )

class CalculationStep(ReasoningStep):
    """Carry out the arithmetic of the original problem, using step 2's inferences as values."""

    def __init__(self):
        super().__init__("calculation", "Perform mathematical calculations")
        calc_instructions = "Identify the mathematical problem to calculate, perform the calculation and provide the result."
        self.calculator = dspy.ChainOfThought(
            "problem, values -> reasoning, calculation, result",
            instructions=calc_instructions,
        )

    def extract_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]:
        # Reads the positional key written by MultiStepReasoner for step 2
        # (the inference step in the demo pipeline).
        second_step = context.get('step_2_result', {})
        return dict(
            problem=context.get('original_problem', ''),
            values=second_step.get('inferences', ''),
        )

    def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        prediction = self.calculator(problem=inputs['problem'], values=inputs['values'])
        return dict(
            calculation_steps=prediction.calculation,
            final_result=prediction.result,
            reasoning=prediction.reasoning,
        )

# Usage example
def demonstrate_multi_step_reasoning():
    """Run the fact -> inference -> calculation pipeline on a word problem and print its trace."""

    reasoner = MultiStepReasoner()
    for reasoning_step in (FactExtractionStep(), LogicalInferenceStep(), CalculationStep()):
        reasoner.add_step(reasoning_step)

    # Test complex problem
    complex_problem = """
    Tom has $50, bought 3 pens at $12 each. Then he bought 2 books, each book costs 1/4 of his remaining money.
    How much money does Tom have left?
    """

    result = reasoner(complex_problem)

    print("\nReasoning Results:")
    print(f"Original Problem: {result.problem}")
    print(f"Analysis: {result.analysis}")
    print(f"Plan: {result.plan}")
    print(f"Final Answer: {result.final_answer}")

    print("\nExecution Trace:")
    for number, summary in enumerate(result.step_results, start=1):
        status = 'Success' if summary.get('success', False) else 'Failed'
        print(f"  Step {number} ({summary['step_name']}): {status}")
        detail = (
            f"    Output: {summary['outputs']}"
            if summary['success']
            else f"    Error: {summary['error_message']}"
        )
        print(detail)

    return result

# demo_result = demonstrate_multi_step_reasoning()

2. Task Decomposition and Subtask Coordination

Complex tasks require sound decomposition strategies and effective mechanisms for coordinating the resulting subtasks.

class TaskDecomposer(dspy.Module):
    """Decompose a complex task into subtasks, run them in order, and integrate the results.

    Pipeline: ``decomposer`` -> ``parse_subtasks`` -> ``dependency_analyzer`` ->
    ``parse_dependencies`` -> ``determine_execution_order`` -> ``subtask_executor``
    (once per subtask) -> ``integrate_results``.
    """

    # Numbered item with a name, e.g. "1. Name: description", "2) **Name:** description",
    # "### 3. Name: description" (leading markdown header/emphasis characters tolerated).
    _NUMBERED_SUBTASK = re.compile(r'^[#*\s]*(\d+)[.)]\s*([^:]+?)\s*:\s*(.+)$')
    # Any list item ("1. ...", "2) ...", "- ...", "* ...", "• ..."), used as a fallback.
    _LIST_ITEM = re.compile(r'^[#\s]*(?:\d+[.)]|[-*\u2022])\s+(.+)$')

    def __init__(self):
        super().__init__()

        # Task decomposition module
        self.decomposer = dspy.ChainOfThought(
            "complex_task -> reasoning, subtasks",
            instructions="""Decompose a complex task into multiple independent subtasks:
            1. Each subtask should be relatively independent
            2. Dependencies between subtasks should be clear
            3. Each subtask has clear inputs and outputs
            4. Arrange subtasks in execution order"""
        )

        # Dependency analysis module
        self.dependency_analyzer = dspy.ChainOfThought(
            "subtasks -> reasoning, dependencies",
            instructions="Analyze dependencies between subtasks and determine execution order."
        )

        # Subtask executor
        self.subtask_executor = dspy.ChainOfThought(
            "subtask, context -> reasoning, result",
            instructions="Execute a specific subtask, producing results based on context information."
        )

        # Result integrator. It is created once here so that it is a registered
        # sub-module (visible to optimizers) instead of being rebuilt on every call.
        self.result_integrator = dspy.ChainOfThought(
            "original_task, subtask_results -> reasoning, integrated_result",
            instructions="Based on execution results of each subtask, integrate a complete solution to the original task."
        )

    def forward(self, complex_task: str):
        """Decompose and execute ``complex_task``.

        Returns:
            ``dspy.Prediction`` with ``complex_task``, ``subtasks``, ``dependencies``,
            ``execution_order`` (subtask names), ``subtask_results`` and ``final_result``.
        """

        print(f"Starting task decomposition: {complex_task}")

        # 1. Task decomposition
        decomposition_result = self.decomposer(complex_task=complex_task)

        # Parse subtasks
        subtasks = self.parse_subtasks(decomposition_result.subtasks)

        print(f"Decomposed into {len(subtasks)} subtasks:")
        for i, subtask in enumerate(subtasks, 1):
            print(f"  {i}. {subtask['name']}: {subtask['description']}")

        # 2. Analyze dependencies
        dependency_result = self.dependency_analyzer(
            subtasks=decomposition_result.subtasks
        )

        dependencies = self.parse_dependencies(dependency_result.dependencies)

        print(f"\nDependencies: {dependencies}")

        # 3. Execute subtasks in dependency order
        execution_order = self.determine_execution_order(subtasks, dependencies)

        print(f"\nExecution order: {[task['name'] for task in execution_order]}")

        # Results of completed subtasks are added here, keyed by subtask name,
        # so later subtasks can build on them.
        context = {'original_task': complex_task}
        subtask_results = []

        for subtask in execution_order:
            print(f"\nExecuting subtask: {subtask['name']}")

            try:
                result = self.subtask_executor(
                    subtask=subtask['description'],
                    context=json.dumps(context, ensure_ascii=False)
                )

                subtask_result = {
                    'name': subtask['name'],
                    'description': subtask['description'],
                    'result': result.result,
                    'reasoning': result.reasoning,
                    'success': True
                }

                context[subtask['name']] = result.result

                print(f"Subtask completed: {result.result}")

            except Exception as e:
                subtask_result = {
                    'name': subtask['name'],
                    'description': subtask['description'],
                    'error': str(e),
                    'success': False
                }

                print(f"Subtask failed: {str(e)}")

            subtask_results.append(subtask_result)

        # 4. Integrate results
        final_result = self.integrate_results(complex_task, subtask_results, context)

        return dspy.Prediction(
            complex_task=complex_task,
            subtasks=subtasks,
            dependencies=dependencies,
            execution_order=[task['name'] for task in execution_order],
            subtask_results=subtask_results,
            final_result=final_result
        )

    def parse_subtasks(self, subtasks_text: str) -> List[Dict[str, str]]:
        """Parse the decomposer's free-text subtask list.

        Preferred format: numbered ``"N. Name: description"`` lines (markdown
        emphasis around the name is stripped). If no such line exists, each
        list item becomes a subtask named ``"Subtask i"``. If the text has no
        list items at all, the whole text becomes a single subtask, so the
        pipeline still executes something.

        Returns:
            A list of dicts with ``'id'``, ``'name'`` and ``'description'`` keys.
        """
        lines = [line.strip() for line in str(subtasks_text or '').splitlines()]
        lines = [line for line in lines if line]

        tasks = []
        for line in lines:
            match = self._NUMBERED_SUBTASK.match(line)
            if not match:
                continue
            name = match.group(2).strip(' *_`')
            description = match.group(3).lstrip(' *_').strip()
            if name and description:
                tasks.append({
                    'id': match.group(1),
                    'name': name,
                    'description': description
                })

        if tasks:
            return tasks

        # Fallback 1: list items without a "name:" part
        items = [m.group(1).strip() for m in map(self._LIST_ITEM.match, lines) if m]
        # Fallback 2: the whole text as one subtask
        if not items and lines:
            items = [' '.join(lines)]

        return [
            {'id': str(i), 'name': f"Subtask {i}", 'description': item}
            for i, item in enumerate(items, 1)
        ]

    def parse_dependencies(self, dependencies_text: str) -> Dict[str, List[str]]:
        """Parse the dependency analyzer's output into ``{subtask name: [prerequisite names]}``.

        Currently a stub that always returns ``{}`` (no constraints), so subtasks
        run in the decomposer's order. Production use needs real parsing.
        """
        return {}

    def determine_execution_order(self, subtasks: List[Dict], dependencies: Dict) -> List[Dict]:
        """Order subtasks so that each runs after the subtasks it depends on.

        Args:
            subtasks: Parsed subtasks (dicts with at least a ``'name'`` key).
            dependencies: Maps a subtask name to the names it depends on. Names
                that do not refer to a known subtask are ignored.

        Returns:
            The subtasks in a dependency-respecting order. Ties keep the original
            order. Subtasks caught in a dependency cycle are appended at the end
            in their original order rather than dropped.
        """
        if not dependencies:
            return subtasks

        known = {task['name'] for task in subtasks}
        prerequisites = {
            task['name']: {
                dep for dep in dependencies.get(task['name'], [])
                if dep in known and dep != task['name']
            }
            for task in subtasks
        }

        ordered, done = [], set()
        pending = list(subtasks)
        while pending:
            ready = [task for task in pending if prerequisites[task['name']] <= done]
            if not ready:
                # Cycle: nothing can satisfy its prerequisites any more
                ordered.extend(pending)
                break
            for task in ready:
                ordered.append(task)
                done.add(task['name'])
            ready_ids = {id(task) for task in ready}
            pending = [task for task in pending if id(task) not in ready_ids]

        return ordered

    def integrate_results(self, original_task: str, subtask_results: List[Dict], context: Dict) -> str:
        """Ask the LM to combine the subtask results into one answer.

        ``context`` is accepted for interface compatibility but is not used.
        """
        result = self.result_integrator(
            original_task=original_task,
            subtask_results=json.dumps(subtask_results, ensure_ascii=False, indent=2)
        )

        return result.integrated_result

class ParallelTaskExecutor:
    """Run interdependent callables in dependency-ordered waves on a thread pool.

    Each task is a callable taking one argument: a dict mapping each of its
    dependency ids to that dependency's result. Tasks whose dependencies have
    all completed form a wave. Up to ``max_workers`` of them run concurrently,
    and the next wave starts once the current one has finished.
    """

    def __init__(self, max_workers: int = 3):
        if max_workers < 1:
            # With 0 workers no task could ever be scheduled.
            raise ValueError("max_workers must be at least 1")
        self.max_workers = max_workers
        self.task_queue = []       # pending task dicts, in insertion order
        self.completed_tasks = {}  # task id -> return value
        self.failed_tasks = {}     # task id -> error message (str of the exception)

    def add_task(self, task_id: str, task_func: Callable, dependencies: Optional[List[str]] = None):
        """Queue ``task_func`` under ``task_id``; it runs after all ``dependencies`` complete."""
        self.task_queue.append({
            'id': task_id,
            'func': task_func,
            'dependencies': dependencies or [],
            'status': 'pending'
        })

    def execute_tasks(self) -> Dict[str, Any]:
        """Execute queued tasks wave by wave until none are left or no progress is possible.

        Returns:
            ``{'completed': {id: result}, 'failed': {id: error}, 'remaining': [ids]}``.
            ``remaining`` lists tasks that could not run: their dependencies
            failed, are missing, or are circular.
        """
        print(f"Starting parallel task execution, max workers: {self.max_workers}")

        while self.task_queue:
            ready_tasks = self.get_ready_tasks()

            if not ready_tasks:
                # Each wave finishes before the next is scheduled, so nothing is
                # running in the background. If no task is ready now, none ever
                # will be: stop instead of spinning forever.
                if self.has_unresolvable_dependencies():
                    print("Unresolvable dependencies exist")
                else:
                    print("Circular dependencies detected")
                break

            self.execute_ready_tasks(ready_tasks)

        return {
            'completed': self.completed_tasks,
            'failed': self.failed_tasks,
            'remaining': [task['id'] for task in self.task_queue]
        }

    def get_ready_tasks(self) -> List[Dict]:
        """Return up to ``max_workers`` pending tasks whose dependencies have all completed."""
        ready_tasks = [
            task for task in self.task_queue
            if task['status'] == 'pending'
            and all(dep in self.completed_tasks for dep in task['dependencies'])
        ]
        return ready_tasks[:self.max_workers]  # Limit concurrency

    def execute_ready_tasks(self, ready_tasks: List[Dict]):
        """Run ``ready_tasks`` concurrently and record each outcome.

        Results are recorded on the calling thread, so the bookkeeping dicts are
        never mutated from worker threads.
        """
        jobs = []
        for task in ready_tasks:
            print(f"Executing task: {task['id']}")
            task['status'] = 'running'
            # Snapshot of dependency results passed to the task as its single argument
            dependency_results = {
                dep: self.completed_tasks[dep]
                for dep in task['dependencies']
            }
            jobs.append((task, dependency_results))

        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            futures = [(task, pool.submit(task['func'], deps)) for task, deps in jobs]

            for task, future in futures:
                try:
                    result = future.result()  # re-raises the task's exception
                except Exception as e:
                    self.failed_tasks[task['id']] = str(e)
                    print(f"Task failed: {task['id']} - {str(e)}")
                else:
                    self.completed_tasks[task['id']] = result
                    print(f"Task completed: {task['id']}")
                self.task_queue.remove(task)

    def has_unresolvable_dependencies(self) -> bool:
        """True if some pending task depends on an id that is neither completed nor still queued."""
        remaining_task_ids = {task['id'] for task in self.task_queue}

        return any(
            dep not in self.completed_tasks and dep not in remaining_task_ids
            for task in self.task_queue
            for dep in task['dependencies']
        )

# Usage example
def demonstrate_task_decomposition():
    """Decompose a coffee-shop marketing brief into subtasks and print the integrated result."""

    complex_task = """
    Develop a complete marketing strategy for a new coffee shop, including market research, brand positioning,
    promotion channel selection, budget allocation, and effectiveness evaluation plan.
    """

    outcome = TaskDecomposer()(complex_task)

    print("Task Decomposition Results:")
    print(f"Original Task: {outcome.complex_task}")
    print(f"Final Result: {outcome.final_result}")

    return outcome

def demonstrate_parallel_execution():
    """Schedule four interdependent marketing tasks on a two-worker executor and print their results."""

    executor = ParallelTaskExecutor(max_workers=2)

    # Each task receives a dict of its dependencies' results; sleeps simulate work.
    def market_research(upstream):
        time.sleep(1)
        return "Market research report: Target customers are young professionals who prefer high-quality coffee"

    def competitor_analysis(upstream):
        time.sleep(1)
        return "Competitor analysis: 3 coffee shops nearby, medium pricing"

    def brand_positioning(upstream):
        market_info = upstream.get('market_research', '')
        time.sleep(1)
        return f"Brand positioning: Based on {market_info}, position as premium coffee brand"

    def promotion_strategy(upstream):
        brand = upstream.get('brand_positioning', '')
        competitor = upstream.get('competitor_analysis', '')
        time.sleep(1)
        return f"Promotion strategy: Combining {brand} and {competitor}, adopt social media marketing"

    task_plan = [
        ('market_research', market_research, None),
        ('competitor_analysis', competitor_analysis, None),
        ('brand_positioning', brand_positioning, ['market_research']),
        ('promotion_strategy', promotion_strategy, ['brand_positioning', 'competitor_analysis']),
    ]
    for task_id, task_func, prerequisites in task_plan:
        executor.add_task(task_id, task_func, prerequisites)

    results = executor.execute_tasks()

    print("\nParallel Execution Results:")
    for task_id, output in results['completed'].items():
        print(f"  {task_id}: {output}")

    return results

# demo_decomposition = demonstrate_task_decomposition()
# demo_parallel = demonstrate_parallel_execution()

3. Conditional Execution and Branching Logic

Complex reasoning often requires making decisions based on intermediate results, which calls for conditional execution and branching.

class ConditionalReasoner(dspy.Module):
    """LM-backed helper that evaluates natural-language conditions and picks among options."""

    # Verdict words in the evaluator's free-text answer. The first verdict word
    # found decides the outcome; negations are checked per word so "not met"
    # is read as False.
    _NEGATIVE_WORDS = frozenset({'false', 'no', 'not', 'unmet', 'unsatisfied'})
    _POSITIVE_WORDS = frozenset({'true', 'yes', 'met', 'satisfied'})

    def __init__(self):
        super().__init__()

        # Condition evaluation module
        self.condition_evaluator = dspy.ChainOfThought(
            "context, condition -> reasoning, evaluation_result",
            instructions="Evaluate whether the given condition is met, return true or false with reasoning."
        )

        # Decision making module
        self.decision_maker = dspy.ChainOfThought(
            "context, options -> reasoning, decision",
            instructions="Based on current situation, select the best decision from multiple options."
        )

    def evaluate_condition(self, context: Dict[str, Any], condition: str) -> bool:
        """Ask the LM whether ``condition`` holds for ``context``.

        The free-text verdict is scanned word by word; the first positive or
        negative verdict word wins. An answer with no verdict word counts as False.
        """
        result = self.condition_evaluator(
            context=json.dumps(context, ensure_ascii=False),
            condition=condition
        )

        # str() guards against non-string outputs; the curly apostrophe is
        # normalised so "isn’t" is treated like "isn't".
        verdict = str(result.evaluation_result).lower().replace('\u2019', "'")
        for word in re.findall(r"[a-z]+(?:'t)?", verdict):
            # Whole-word matching: the old substring test read "not met",
            # "untrue" and "metadata" as True.
            if word in self._NEGATIVE_WORDS or word.endswith("n't"):
                return False
            if word in self._POSITIVE_WORDS:
                return True
        return False

    def make_decision(self, context: Dict[str, Any], options: List[str]) -> str:
        """Ask the LM to choose among ``options`` (presented as a numbered list); returns its raw decision text."""
        options_text = '\n'.join([f"{i+1}. {opt}" for i, opt in enumerate(options)])

        result = self.decision_maker(
            context=json.dumps(context, ensure_ascii=False),
            options=options_text
        )

        return result.decision

class BranchingWorkflow:
    """Graph of process, condition and decision nodes, executed from a start node.

    Node types handled by ``execute_workflow``:

    * ``'process'``: calls ``processor(context)``, merges the returned dict into
      the context, then follows the node's outgoing edges (see ``add_edge``).
    * ``'condition'``: evaluates each entry of ``conditions`` (dicts with
      ``'condition'`` and ``'next_node'`` keys) in order and jumps to the first
      one that holds, otherwise to ``default_next``.
    * ``'decision'``: asks the reasoner to choose among the targets of the
      node's outgoing edges.
    """

    def __init__(self):
        self.nodes = {}  # node_id -> node spec dict
        self.edges = {}  # from_node -> list of {'to': node_id, 'condition': str | None}
        self.conditional_reasoner = ConditionalReasoner()

    def add_node(self, node_id: str, node_type: str, processor: Callable,
                 conditions: List[Dict] = None, default_next: Optional[str] = None):
        """Add or replace a workflow node.

        Args:
            node_id: Unique node identifier.
            node_type: ``'process'``, ``'condition'`` or ``'decision'``.
            processor: Callable ``context -> dict`` for process nodes.
            conditions: For condition nodes, the ordered ``{'condition', 'next_node'}`` specs.
            default_next: For condition nodes, the target used when no condition holds.
        """
        self.nodes[node_id] = {
            'id': node_id,
            'type': node_type,
            'processor': processor,
            'conditions': conditions or [],
            'default_next': default_next,
            'next_nodes': []
        }

    def add_edge(self, from_node: str, to_node: str, condition: str = None):
        """Add an edge; ``condition=None`` marks it as the unconditional/default path."""
        self.edges.setdefault(from_node, []).append({
            'to': to_node,
            'condition': condition
        })

    def execute_workflow(self, start_node: str, initial_context: Dict[str, Any],
                         max_steps: int = 100) -> Dict[str, Any]:
        """Walk the graph from ``start_node`` until no next node is selected.

        Args:
            start_node: Id of the first node.
            initial_context: Starting context; it is copied, not mutated.
            max_steps: Upper bound on visited nodes, guarding against cycles.

        Returns:
            ``{'final_context', 'execution_path', 'success'}``. ``success`` is
            False if a node was missing, a node raised, or ``max_steps`` was hit.
        """
        print(f"Starting workflow execution, start node: {start_node}")

        current_node = start_node
        context = initial_context.copy()
        execution_path = []
        success = True

        while current_node:
            if len(execution_path) >= max_steps:
                print(f"Maximum of {max_steps} steps reached, stopping (possible cycle)")
                success = False
                break

            print(f"\nCurrent node: {current_node}")

            if current_node not in self.nodes:
                print(f"Node {current_node} does not exist")
                success = False
                break

            node = self.nodes[current_node]
            execution_path.append(current_node)

            try:
                # Condition and decision nodes choose their own successor and
                # skip the generic edge-following below.
                if node['type'] == 'condition':
                    current_node = self._select_condition_target(node, context)
                    continue

                if node['type'] == 'decision':
                    current_node = self._select_decision_target(current_node, context)
                    continue

                if node['type'] == 'process':
                    result = node['processor'](context)
                    if result:  # a processor returning None simply adds nothing
                        context.update(result)
                    print(f"Processing result: {result}")

                current_node = self._select_next_node(current_node, context)

            except Exception as e:
                print(f"Node execution failed: {str(e)}")
                success = False
                break

        print(f"\nWorkflow execution completed")
        print(f"Execution path: {' -> '.join(execution_path)}")

        return {
            'final_context': context,
            'execution_path': execution_path,
            'success': success
        }

    def _select_condition_target(self, node: Dict[str, Any], context: Dict[str, Any]) -> Optional[str]:
        """Return the target of the first condition that holds, else the node's ``default_next``."""
        for condition_spec in node['conditions']:
            condition_met = self.conditional_reasoner.evaluate_condition(
                context, condition_spec['condition']
            )

            print(f"Condition evaluation: {condition_spec['condition']} -> {condition_met}")

            if condition_met:
                return condition_spec['next_node']
        return node.get('default_next')

    def _select_decision_target(self, node_id: str, context: Dict[str, Any]) -> Optional[str]:
        """Let the reasoner pick among the node's edge targets; fall back to the first target."""
        options = [edge['to'] for edge in self.edges.get(node_id, [])]
        if not options:
            return None

        decision = self.conditional_reasoner.make_decision(context, options)
        print(f"Decision result: {decision}")

        decision_text = str(decision).strip()
        if decision_text:
            for option in options:
                if option in decision_text or decision_text in option:
                    return option
        return options[0]  # Default to first option

    def _select_next_node(self, node_id: str, context: Dict[str, Any]) -> Optional[str]:
        """Follow outgoing edges: a met conditional edge wins, else the (last) unconditional edge."""
        edges = self.edges.get(node_id, [])

        if len(edges) == 1 and edges[0]['condition'] is None:
            return edges[0]['to']

        next_node = None
        for edge in edges:
            if edge['condition'] is None:
                next_node = edge['to']  # default path; a later met condition overrides it
            elif self.conditional_reasoner.evaluate_condition(context, edge['condition']):
                return edge['to']
        return next_node

# Practical application example: Intelligent customer service branching logic
class IntelligentCustomerService:
    """Customer-service demo on ``BranchingWorkflow``: classify, route to a handler, respond."""

    def __init__(self):
        self.workflow = BranchingWorkflow()
        self.setup_workflow()

    def setup_workflow(self):
        """Register the nodes and routing edges of the customer-service workflow."""

        # Build each LM predictor once. The node functions below close over
        # them instead of constructing a fresh ChainOfThought on every inquiry.
        classifier = dspy.ChainOfThought(
            "customer_message -> reasoning, category",
            instructions="Classify customer inquiry: technical issue, billing issue, product inquiry, complaint/suggestion, etc."
        )
        technical_handler = dspy.ChainOfThought(
            "technical_problem -> reasoning, solution",
            instructions="Provide solution for technical problem."
        )
        billing_handler = dspy.ChainOfThought(
            "billing_question -> reasoning, response",
            instructions="Handle billing-related questions."
        )
        response_generator = dspy.ChainOfThought(
            "context, solution -> reasoning, customer_response",
            instructions="Based on processing results, generate final customer response."
        )

        # Node processing functions: each takes the workflow context and
        # returns a dict that is merged back into it.
        def classify_inquiry(context):
            result = classifier(customer_message=context['customer_message'])
            return {'inquiry_category': result.category, 'classification_reasoning': result.reasoning}

        def handle_technical_issue(context):
            result = technical_handler(technical_problem=context['customer_message'])
            return {'solution': result.solution, 'solution_reasoning': result.reasoning}

        def handle_billing_inquiry(context):
            result = billing_handler(billing_question=context['customer_message'])
            return {'response': result.response, 'response_reasoning': result.reasoning}

        def escalate_to_human(context):
            return {
                'escalation': True,
                'escalation_reason': 'Complex issue requiring human handling',
                # Snapshot, not the live dict. The result is merged back into
                # `context`, and storing `context` itself would make it contain
                # itself (json.dumps then fails with a circular reference).
                'context_for_agent': dict(context)
            }

        def generate_final_response(context):
            result = response_generator(
                context=json.dumps(context, ensure_ascii=False),
                solution=context.get('solution', context.get('response', ''))
            )

            return {'final_response': result.customer_response}

        # Add nodes
        self.workflow.add_node('classify', 'process', classify_inquiry)
        self.workflow.add_node('technical_handler', 'process', handle_technical_issue)
        self.workflow.add_node('billing_handler', 'process', handle_billing_inquiry)
        self.workflow.add_node('escalation', 'process', escalate_to_human)
        self.workflow.add_node('final_response', 'process', generate_final_response)

        # Add conditional branches
        self.workflow.add_edge('classify', 'technical_handler', 'technical issue')
        self.workflow.add_edge('classify', 'billing_handler', 'billing issue')
        self.workflow.add_edge('classify', 'escalation', 'complaint/suggestion')
        # Unconditional default: inquiries matching no branch (e.g. product
        # inquiries) go to a human instead of ending with no response at all.
        self.workflow.add_edge('classify', 'escalation')

        self.workflow.add_edge('technical_handler', 'final_response')
        self.workflow.add_edge('billing_handler', 'final_response')
        # Escalation doesn't need final response

    def handle_customer_inquiry(self, customer_message: str) -> Dict[str, Any]:
        """Run the workflow for one message; returns the workflow's result dict."""
        initial_context = {'customer_message': customer_message}
        return self.workflow.execute_workflow('classify', initial_context)

# Usage example
def demonstrate_conditional_reasoning():
    """Demonstrate conditional reasoning"""

    # Create intelligent customer service
    customer_service = IntelligentCustomerService()

    # Test different types of inquiries
    test_inquiries = [
        "My software won't open, what should I do?",
        "Why is this month's bill $50 more than last month?",
        "Your service is terrible, I want to complain!"
    ]

    for inquiry in test_inquiries:
        print(f"\n" + "="*60)
        print(f"Customer Inquiry: {inquiry}")

        result = customer_service.handle_customer_inquiry(inquiry)

        final_context = result['final_context']
        execution_path = result['execution_path']

        print(f"Execution path: {' -> '.join(execution_path)}")

        if 'final_response' in final_context:
            print(f"Customer service response: {final_context['final_response']}")
        elif 'escalation' in final_context:
            print(f"Escalate: {final_context['escalation_reason']}")

    return result

# demo_conditional = demonstrate_conditional_reasoning()

4. Error Handling and Exception Recovery

Robust multi-step reasoning systems need comprehensive error handling mechanisms.

class RobustReasoner(dspy.Module):
    """Robust reasoner with error handling capabilities.

    DSPy ``ChainOfThought`` predictors analyse a failure and, on request,
    propose an alternative solution. The recovery actions themselves
    (retry / alternative / skip / backtrack) are plain Python.
    """

    def __init__(self):
        super().__init__()

        self.error_analyzer = dspy.ChainOfThought(
            "error_info, context -> reasoning, error_type, recovery_strategy",
            instructions="Analyze error type and propose recovery strategy."
        )

        # NOTE: not called by any method of this class.
        self.recovery_planner = dspy.ChainOfThought(
            "original_plan, error_info, recovery_strategy -> reasoning, revised_plan",
            instructions="Based on error information and recovery strategy, revise the original plan."
        )

        self.alternative_solver = dspy.ChainOfThought(
            "problem, failed_approach, context -> reasoning, alternative_solution",
            instructions="When original method fails, find alternative solutions."
        )

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialize ``value`` for use as a prompt field.

        ``default=str`` stringifies anything the json module cannot encode,
        such as exceptions or prediction objects stored in the context.
        Without it, the error-handling path itself could fail with a
        ``TypeError`` while trying to report the original error.
        """
        return json.dumps(value, ensure_ascii=False, default=str)

    def handle_error(self, error: Exception, context: Dict[str, Any],
                    current_step: str) -> Dict[str, Any]:
        """Ask the LM to classify ``error`` and propose a recovery strategy.

        Args:
            error: The exception raised by (or reported for) the failed step.
            context: The reasoning context at the time of failure.
            current_step: Name of the step that failed.

        Returns:
            dict with ``error_type``, ``recovery_strategy`` and
            ``analysis_reasoning`` taken from the error analyzer's output.
        """

        error_info = {
            'error_type': type(error).__name__,
            'error_message': str(error),
            'failed_step': current_step,
            'context_state': context
        }

        print(f"Error handling: {error_info['error_type']} - {error_info['error_message']}")

        # Analyze error
        analysis_result = self.error_analyzer(
            error_info=self._to_json(error_info),
            context=self._to_json(context)
        )

        recovery_strategy = {
            'error_type': analysis_result.error_type,
            'recovery_strategy': analysis_result.recovery_strategy,
            'analysis_reasoning': analysis_result.reasoning
        }

        print(f"Recovery strategy: {recovery_strategy['recovery_strategy']}")

        return recovery_strategy

    def attempt_recovery(self, recovery_strategy: Dict[str, Any],
                        context: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch to a recovery action by keyword-matching the strategy text.

        Keywords are checked in the order retry, alternative, skip,
        backtrack, and the first match wins. Any other text yields
        ``{'recovery_action': 'no_recovery', 'success': False}``.
        """

        # The strategy text comes from an LM, so tolerate a missing or
        # non-string value instead of crashing inside error handling.
        strategy_type = str(recovery_strategy.get('recovery_strategy') or '').lower()

        if 'retry' in strategy_type:
            return self.retry_with_modification(context)

        elif 'alternative' in strategy_type:
            return self.find_alternative_approach(context)

        elif 'skip' in strategy_type:
            return self.skip_step_and_continue(context)

        elif 'backtrack' in strategy_type:
            return self.backtrack_and_retry(context)

        else:
            return {'recovery_action': 'no_recovery', 'success': False}

    def retry_with_modification(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Return a shallow copy of ``context`` with a bumped ``retry_count`` and ``modified=True``."""
        print("Attempting retry with modified parameters...")

        modified_context = context.copy()
        modified_context['retry_count'] = modified_context.get('retry_count', 0) + 1
        modified_context['modified'] = True

        return {
            'recovery_action': 'retry',
            'modified_context': modified_context,
            'success': True
        }

    def find_alternative_approach(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the LM for an alternative to the approach recorded in ``context``."""
        print("Finding alternative solution...")

        original_problem = context.get('original_problem', '')
        failed_approach = context.get('current_approach', '')

        alternative_result = self.alternative_solver(
            problem=original_problem,
            failed_approach=failed_approach,
            context=self._to_json(context)
        )

        return {
            'recovery_action': 'alternative',
            'alternative_solution': alternative_result.alternative_solution,
            'reasoning': alternative_result.reasoning,
            'success': True
        }

    def skip_step_and_continue(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Report a successful 'skip' action for ``context['current_step']`` (if set)."""
        print("Skipping current step...")

        return {
            'recovery_action': 'skip',
            'skipped_step': context.get('current_step', ''),
            'success': True
        }

    def backtrack_and_retry(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Return the last entry of ``context['previous_states']`` as the restored context.

        Fails (``success=False``, ``reason='no_previous_state'``) when the
        context carries no ``previous_states`` list or it is empty.
        """
        print("Backtracking to previous step...")

        previous_state = context.get('previous_states', [])

        if previous_state:
            restored_context = previous_state[-1]
            return {
                'recovery_action': 'backtrack',
                'restored_context': restored_context,
                'success': True
            }
        else:
            return {
                'recovery_action': 'backtrack',
                'success': False,
                'reason': 'no_previous_state'
            }

class ResilientMultiStepReasoner(MultiStepReasoner):
    """Multi-step reasoner that retries and recovers from failing steps.

    Each reasoning step gets up to ``max_retries + 1`` attempts. After a
    failed attempt, a ``RobustReasoner`` analyses the error and proposes a
    recovery action:

    * ``retry``: merge the returned ``modified_context`` and run the step again.
    * ``backtrack``: replace the context with the returned
      ``restored_context`` and run the step again.
    * ``alternative``: run the step again. The proposed alternative
      solution is not fed back into the context.
    * ``skip``: stop working on this step and move on to the next one.

    Every step, including unrecoverable, skipped and exhausted ones, ends
    up in ``step_results``, so the synthesizer sees the complete picture.

    NOTE(review): ``self.context``, ``self.problem_analyzer``,
    ``self.reasoning_steps`` and ``self.result_synthesizer`` come from
    ``MultiStepReasoner``, which is defined elsewhere; confirm there.
    """

    def __init__(self, max_retries: int = 3):
        super().__init__()
        self.max_retries = max_retries
        self.robust_reasoner = RobustReasoner()
        # Shallow snapshots of self.context, taken before every step attempt.
        self.step_history = []

    def forward(self, problem: str):
        """Run problem analysis, every reasoning step (with recovery), then synthesis.

        Args:
            problem: The problem statement to reason about.

        Returns:
            A ``dspy.Prediction`` with ``problem``, ``step_results``,
            ``final_answer``, ``step_history`` and ``recovery_attempts``.
            ``recovery_attempts`` is the number of steps that needed at
            least one retry. If problem analysis fails, the Prediction
            carries only ``problem``, ``error`` and ``final_answer``.
        """

        print(f"Starting robust multi-step reasoning: {problem}")

        self.context['original_problem'] = problem
        self.step_history = []

        # Without an analysis there is nothing to reason over, so fail fast
        # with a descriptive prediction instead of raising.
        try:
            analysis_result = self.problem_analyzer(problem=problem)
            self.context['analysis'] = analysis_result.analysis
            self.context['sub_problems'] = analysis_result.sub_problems
        except Exception as e:
            print(f"Problem analysis failed: {e}")
            return dspy.Prediction(
                problem=problem,
                error="Problem analysis failed",
                final_answer="Unable to analyze problem"
            )

        step_results = []
        for step_no, step in enumerate(self.reasoning_steps, start=1):
            step_results.append(self._run_step_with_recovery(step_no, step))

        # Result synthesis (considering partial failures)
        try:
            final_result = self.result_synthesizer(
                original_problem=problem,
                step_results=json.dumps(step_results, ensure_ascii=False, indent=2)
            )
            final_answer = final_result.final_answer
        except Exception as e:
            print(f"Result synthesis issue: {e}")
            final_answer = "Some steps failed, unable to provide complete answer"

        return dspy.Prediction(
            problem=problem,
            step_results=step_results,
            final_answer=final_answer,
            step_history=self.step_history,
            recovery_attempts=sum(1 for result in step_results if result.get('retry_count', 0) > 0)
        )

    def _run_step_with_recovery(self, step_no: int, step) -> Dict[str, Any]:
        """Run one step, retrying after recoverable failures; return its summary."""
        retry_count = 0
        while True:
            print(f"\nExecuting step {step_no} (attempt {retry_count+1}): {step.step_name}")

            # Save current state before each attempt.
            self.step_history.append(self.context.copy())

            try:
                step_output = step.execute(self.context)
                if not step.success:
                    # execute() records failures instead of raising; surface
                    # them so both kinds of failure share the path below.
                    raise RuntimeError(step.error_message)
                self.context[f'step_{step_no}_result'] = step_output
                step_summary = step.get_summary()
                # Record retries on successes too, so recovery_attempts also
                # counts steps that only succeeded after a retry.
                if isinstance(step_summary, dict):
                    step_summary = {**step_summary, 'retry_count': retry_count}
            except Exception as e:
                error = e
            else:
                print(f"Step {step_no} completed successfully")
                return step_summary

            print(f"Step {step_no} failed: {error}")

            if retry_count >= self.max_retries:
                print(f"Reached max retries, step {step_no} ultimately failed")
                return self._failure_summary(step, error, retry_count)

            recovery_result = self._attempt_recovery(error, step)
            if not recovery_result.get('success'):
                print(f"Unable to recover, skipping step {step_no}")
                return self._failure_summary(step, error, retry_count)

            action = recovery_result.get('recovery_action')
            print(f"Applying recovery strategy: {action}")

            if action == 'skip':
                # Honour the skip: move on rather than re-running the step.
                skipped_summary = self._failure_summary(step, error, retry_count)
                skipped_summary['skipped'] = True
                return skipped_summary

            if 'modified_context' in recovery_result:
                self.context.update(recovery_result['modified_context'])
            elif 'restored_context' in recovery_result:
                # Copy so later writes don't rewrite the saved snapshot.
                self.context = dict(recovery_result['restored_context'])

            retry_count += 1

    def _attempt_recovery(self, error: Exception, step) -> Dict[str, Any]:
        """Ask the RobustReasoner for a recovery action, reporting its own failures as no recovery."""
        try:
            recovery_strategy = self.robust_reasoner.handle_error(
                error, self.context, step.step_name
            )
            return self.robust_reasoner.attempt_recovery(
                recovery_strategy, self.context
            )
        except Exception as recovery_error:
            # Error analysis calls an LM and serializes the context; if that
            # breaks, treat the step as unrecoverable instead of aborting forward().
            print(f"Recovery analysis failed: {recovery_error}")
            return {'recovery_action': 'no_recovery', 'success': False}

    @staticmethod
    def _failure_summary(step, error: Exception, retry_count: int) -> Dict[str, Any]:
        """Build the summary recorded for a step that did not succeed."""
        return {
            'step_name': step.step_name,
            'success': False,
            'error_message': str(error),
            'retry_count': retry_count
        }

# Usage example
def demonstrate_error_handling():
    """Exercise ResilientMultiStepReasoner with a deliberately flaky first step.

    The reasoner is built with ``max_retries=2``. Its first step raises a
    simulated error with probability 0.6, and it is followed by a
    ``FactExtractionStep``. After running, the reasoner's summary and a
    per-step status list are printed.

    Returns:
        The prediction produced by the reasoner.
    """
    import random

    class UnreliableStep(ReasoningStep):
        """Toy step that fails with probability ``failure_rate``."""

        def __init__(self, failure_rate: float = 0.7):
            super().__init__("unreliable_step", "Step prone to failure")
            self.failure_rate = failure_rate

        def extract_inputs(self, context: Dict[str, Any]) -> Dict[str, Any]:
            return {'input': context.get('original_problem', '')}

        def process(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
            should_fail = random.random() < self.failure_rate
            if should_fail:
                raise Exception("Simulated random failure")
            return {'output': f"Processing result: {inputs['input']}"}

    reasoner = ResilientMultiStepReasoner(max_retries=2)
    reasoner.add_step(UnreliableStep(failure_rate=0.6))
    reasoner.add_step(FactExtractionStep())

    result = reasoner("Testing error handling and recovery mechanism")

    print("\nError Handling Test Results:")
    print(f"Problem: {result.problem}")
    print(f"Recovery attempts: {result.recovery_attempts}")
    print(f"Final answer: {result.final_answer}")

    print("\nStep execution status:")
    for position, summary in enumerate(result.step_results, start=1):
        retries = summary.get('retry_count', 0)
        status = "Success" if summary.get('success', False) else "Failed"
        retry_info = f" ({retries} retries)" if retries > 0 else ""
        print(f"  Step {position}: {summary.get('step_name', 'Unknown')} - {status}{retry_info}")

    return result

# demo_error_handling = demonstrate_error_handling()

Practice Exercises

Exercise 1: Design Domain-Specific Multi-Step Reasoning

class DomainSpecificReasoner:
    """Exercise scaffold: a reasoner specialised for a single problem domain.

    So far only the domain name is stored. Building the reasoning steps
    and validating results against domain constraints are left to the
    reader.
    """

    def __init__(self, domain: str):
        # Target domain, e.g. "medical diagnosis" or "legal reasoning".
        self.domain = domain
        # TODO: implement the domain-specific reasoning logic

    def create_domain_steps(self):
        """Build the reasoning steps for this domain (exercise; currently returns None)."""
        # TODO: design reasoning steps around the domain's characteristics
        return None

    def validate_domain_constraints(self, result):
        """Check ``result`` against the domain's constraints (exercise; currently returns None)."""
        # TODO: implement domain-specific result validation
        return None

# Exercise tasks:
# 1. Choose a specific domain (e.g., medical diagnosis, legal reasoning, engineering design)
# 2. Design multi-step reasoning process for that domain
# 3. Implement domain-specific constraint validation

Exercise 2: Optimize Reasoning Performance

class OptimizedReasoner:
    """Exercise scaffold: a reasoner tuned for speed.

    It holds an empty result cache and a ``None`` executor placeholder.
    The three optimisation hooks are left for the reader to implement.
    """

    def __init__(self):
        # Reasoning results to be memoised once implement_caching is written.
        self.cache = {}
        # Placeholder for the executor enable_parallel_execution will create.
        self.parallel_executor = None

    def implement_caching(self):
        """Set up result caching (exercise; currently returns None)."""
        # TODO: build the reasoning-result caching mechanism
        return None

    def enable_parallel_execution(self):
        """Run independent steps concurrently (exercise; currently returns None)."""
        # TODO: execute independent steps in parallel
        return None

    def optimize_step_order(self, steps):
        """Reorder ``steps`` by their dependencies (exercise; currently returns None)."""
        # TODO: derive an execution order from step dependencies
        return None

# Exercise tasks:
# 1. Implement smart caching mechanism
# 2. Identify parallelizable reasoning steps
# 3. Implement dynamic step order optimization

Best Practices

1. Reasoning Chain Design Principles

def reasoning_design_principles():
    """Return the chapter's reasoning-chain design guidelines.

    Returns:
        dict mapping each principle category to a list of short guideline
        strings. A fresh dict with fresh lists is built on every call, so
        callers may mutate the result freely.
    """
    catalogue = (
        ('Modular Design', (
            'Each reasoning step has a single responsibility',
            'Clear interfaces between steps',
            'Easy to test and debug independently',
        )),
        ('Error Handling', (
            'Anticipate possible failure points',
            'Design graceful degradation strategies',
            'Save sufficient recovery information',
        )),
        ('Performance Optimization', (
            'Identify parallelizable steps',
            'Implement result caching',
            'Avoid unnecessary repeated computations',
        )),
        ('Explainability', (
            'Record detailed reasoning process',
            'Provide explanations for intermediate results',
            'Support reasoning path visualization',
        )),
    )
    return {category: list(guidelines) for category, guidelines in catalogue}

class ReasoningPatternLibrary:
    """Catalogue of reusable reasoning patterns and when to apply them.

    Every accessor returns a description in the same layout: a leading
    newline, then "Use case", "Characteristics" and "Note" lines, each
    indented by 8 spaces, and a trailing 8-space indent.
    """

    @staticmethod
    def _describe(use_case: str, characteristics: str, note: str) -> str:
        """Format one pattern description in the library's fixed text layout."""
        indent = " " * 8
        rows = (
            f"Use case: {use_case}",
            f"Characteristics: {characteristics}",
            f"Note: {note}",
        )
        return "\n" + "".join(f"{indent}{row}\n" for row in rows) + indent

    @staticmethod
    def sequential_reasoning():
        """Describe the sequential pattern: steps run one after another."""
        return ReasoningPatternLibrary._describe(
            "Steps have strict dependencies",
            "Simple and intuitive, easy to understand and debug",
            "Requires good error handling mechanism",
        )

    @staticmethod
    def parallel_reasoning():
        """Describe the parallel pattern: independent sub-problems run concurrently."""
        return ReasoningPatternLibrary._describe(
            "Multiple independent sub-problems",
            "Improves execution efficiency",
            "Consider resource competition and result synchronization",
        )

    @staticmethod
    def iterative_reasoning():
        """Describe the iterative pattern: a solution refined over several passes."""
        return ReasoningPatternLibrary._describe(
            "Problems requiring gradual refinement",
            "Supports progressive solution",
            "Design appropriate convergence conditions",
        )

    @staticmethod
    def hierarchical_reasoning():
        """Describe the hierarchical pattern: reasoning across abstraction levels."""
        return ReasoningPatternLibrary._describe(
            "Complex problems with hierarchical structure",
            "Supports reasoning at different abstraction levels",
            "Properly design information transfer between levels",
        )

2. Performance Monitoring and Debugging

class ReasoningProfiler:
    """Collect per-step timing and memory statistics for reasoning functions.

    Wrap callables with :meth:`profile_reasoning_step`. Each call stores its
    measurements in ``performance_data``, keyed by the function's
    ``__name__``; a later call under the same name overwrites the earlier
    entry.
    """

    def __init__(self):
        # step name -> metrics dict written by the profiling wrapper
        self.performance_data = {}
        # NOTE: not written by any method of this class.
        self.debugging_info = {}

    def profile_reasoning_step(self, step_func):
        """Decorate ``step_func`` so every call is timed and memory-traced.

        Each call records wall-clock duration, current and peak traced
        memory (bytes), a success flag and the error text in
        ``performance_data``.

        If ``step_func`` raises an ``Exception``, the error is recorded and
        the wrapper returns ``None`` instead of propagating it. Any other
        ``BaseException`` (e.g. ``KeyboardInterrupt``) is recorded as a
        failure and then re-raised.

        Nested profiled calls share one ``tracemalloc`` session. Tracing is
        started and stopped only by a wrapper that found it disabled, so an
        inner call cannot stop tracing underneath an outer one or one
        enabled by someone else.
        """
        import functools
        import time
        import tracemalloc

        @functools.wraps(step_func)
        def wrapper(*args, **kwargs):
            step_name = getattr(step_func, '__name__', 'unknown')

            owns_tracing = not tracemalloc.is_tracing()
            if owns_tracing:
                tracemalloc.start()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

            # Pre-bind so the finally block is safe even for BaseExceptions
            # that the except clause below does not catch.
            result, success, error = None, False, None
            try:
                result = step_func(*args, **kwargs)
                success = True
            except Exception as e:
                error = str(e)
            finally:
                elapsed = time.perf_counter() - start_time
                current, peak = tracemalloc.get_traced_memory()
                if owns_tracing:
                    tracemalloc.stop()

                self.performance_data[step_name] = {
                    'execution_time': elapsed,
                    'memory_current': current,
                    'memory_peak': peak,
                    'success': success,
                    'error': error
                }

            return result

        return wrapper

    def generate_performance_report(self) -> str:
        """Render ``performance_data`` as a human-readable text report.

        Per-step peak memory is converted from bytes to MB. The "total
        memory" line is the sum of per-step peaks, not a true process-wide
        peak.

        Returns:
            The report text, or "No performance data" if nothing was recorded.
        """
        if not self.performance_data:
            return "No performance data"

        report = ["Reasoning Performance Analysis Report", "=" * 40]

        total_time = sum(data['execution_time'] for data in self.performance_data.values())
        total_memory = sum(data['memory_peak'] for data in self.performance_data.values())

        report.append(f"Total execution time: {total_time:.3f}s")
        report.append(f"Total memory usage: {total_memory / 1024 / 1024:.2f}MB")

        report.append("\nStep Details:")
        for step_name, data in self.performance_data.items():
            status = "Success" if data['success'] else "Failed"
            report.append(f"  {step_name} {status}")
            report.append(f"    Time: {data['execution_time']:.3f}s")
            report.append(f"    Memory: {data['memory_peak'] / 1024 / 1024:.2f}MB")

            if not data['success']:
                report.append(f"    Error: {data['error']}")

        return "\n".join(report)

# Reasoning system monitoring
class ReasoningMonitor:
    """Aggregate request-level metrics and derive a coarse health status."""

    def __init__(self):
        self.metrics = {
            'total_requests': 0,
            'successful_requests': 0,
            'failed_requests': 0,
            'average_response_time': 0.0,
            # step name -> {'total': runs, 'failures': failed runs} (raw counts)
            'step_failure_rates': {}
        }

    def record_request(self, success: bool, response_time: float, step_failures: Dict[str, bool]):
        """Fold one finished request into the running metrics.

        Args:
            success: Whether the request as a whole succeeded.
            response_time: Duration of the request, in whatever unit the caller uses.
            step_failures: Maps each executed step name to True if that step failed.
        """
        stats = self.metrics
        stats['total_requests'] += 1
        outcome_key = 'successful_requests' if success else 'failed_requests'
        stats[outcome_key] += 1

        # Running mean: rescale the old mean back to a sum, then add the sample.
        n = stats['total_requests']
        previous_mean = stats['average_response_time']
        stats['average_response_time'] = (previous_mean * (n - 1) + response_time) / n

        per_step = stats['step_failure_rates']
        for name, did_fail in step_failures.items():
            counters = per_step.setdefault(name, {'total': 0, 'failures': 0})
            counters['total'] += 1
            if did_fail:
                counters['failures'] += 1

    def get_health_status(self) -> Dict[str, Any]:
        """Classify health by overall success rate.

        A rate of 95% or more is 'healthy', 80% or more is 'warning', and
        anything lower is 'critical'. Before any request has been recorded,
        the status is 'unknown' with reason 'no_data'.
        """
        total = self.metrics['total_requests']
        if total == 0:
            return {'status': 'unknown', 'reason': 'no_data'}

        success_rate = self.metrics['successful_requests'] / total

        status = 'critical'
        for threshold, label in ((0.95, 'healthy'), (0.80, 'warning')):
            if success_rate >= threshold:
                status = label
                break

        return {
            'status': status,
            'success_rate': success_rate,
            'average_response_time': self.metrics['average_response_time'],
            'total_requests': total
        }

After working through this chapter, you should know how to build complex multi-step reasoning systems with DSPy. These techniques help you tackle problems that require decomposition and coordination, and build more intelligent and robust AI applications.