# File: backend/app/services/workflow_service.py (Refactor)
# Description: Service layer that executes workflows using the component-based approach.

from typing import List, Dict, Optional, Any

import graphlib  # Standard-library topological sort
import traceback

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.pydantic_models import NodeModel, EdgeModel, WorkflowRunResponse
from app.flow_components.base import component_registry, BaseComponent  # Import registry and base


class WorkflowExecutionError(Exception):
    """Raised when a node cannot be executed (e.g. a required input is missing)."""
    pass


class WorkflowService:
    """Service that executes workflows (component-based)."""

    async def execute_workflow(
        self,
        nodes: List[NodeModel],
        edges: List[EdgeModel],
        db: Optional[AsyncSession] = None,  # Pass a DB session if components need it
    ) -> WorkflowRunResponse:
        print("Starting workflow execution (component-based)...")

        # 1. Build the dependency graph and topologically sort it
        try:
            graph: Dict[str, set[str]] = {node.id: set() for node in nodes}
            node_map: Dict[str, NodeModel] = {node.id: node for node in nodes}
            # {node_id: {handle_id: input_name}}
            handle_to_node_map: Dict[str, Dict[str, str]] = {node.id: {} for node in nodes}

            # Pre-populate the handle map based on component definitions
            for node in nodes:
                if node.type in component_registry:
                    component_cls = component_registry[node.type]
                    for input_field in component_cls.inputs:
                        if input_field.is_handle:
                            # Handle IDs match the input field names, so map them 1:1
                            handle_to_node_map[node.id][input_field.name] = input_field.name

            # graphlib expects {node: {its dependencies}}, so for every edge the
            # target node depends on the source node.
            for edge in edges:
                if edge.source in graph and edge.target in graph:
                    graph[edge.target].add(edge.source)

            sorter = graphlib.TopologicalSorter(graph)
            execution_order = list(sorter.static_order())
            print(f"Execution order: {execution_order}")

        except graphlib.CycleError as e:
            print(f"Cycle detected in workflow: {e}")
            return WorkflowRunResponse(success=False, message="A cycle was detected in the workflow; it cannot be executed.")
        except Exception as e:
            print(f"Error while building the execution graph: {e}")
            return WorkflowRunResponse(success=False, message=f"Failed to build the execution graph: {e}")

        # 2. Execute the nodes in topological order
        node_outputs: Dict[str, Dict[str, Any]] = {}  # {node_id: {output_handle_id: value}}
        final_output_value: Any = None
        final_output_node_id: Optional[str] = None

        for node_id in execution_order:
            node = node_map.get(node_id)
            if not node:
                print(f"Error: node {node_id} not found")
                continue  # Should not happen if the graph was built correctly

            component_cls = component_registry.get(node.type)
            if not component_cls:
                print(f"Warning: no backend component registered for node type '{node.type}', skipping node {node_id}")
                continue

            print(f"\n--- Executing node: {node_id} (type: {node.type}) ---")

            # a. Instantiate the component with the node's data
            component_instance = component_cls(node_data=node.data.model_dump())
            # b. Collect input values from connected upstream nodes
            inputs_for_run: Dict[str, Any] = {}
            try:
                for edge in edges:
                    if edge.target == node_id:
                        source_node_id = edge.source
                        source_handle_id = edge.sourceHandle
                        target_handle_id = edge.targetHandle

                        if source_node_id in node_outputs and source_handle_id in node_outputs[source_node_id]:
                            # Map the target handle ID to the input name expected by the component's run method
                            input_name = handle_to_node_map.get(node_id, {}).get(target_handle_id)
                            if input_name:
                                input_value = node_outputs[source_node_id][source_handle_id]
                                inputs_for_run[input_name] = input_value
                                print(f"  Input '{input_name}' from {source_node_id}.{source_handle_id} = {str(input_value)[:50]}...")
                            else:
                                print(f"Warning: no input field found for target handle '{target_handle_id}' on node {node_id}.")
                        else:
                            # The source node may not have run yet or did not produce the expected output
                            print(f"Warning: no output found from {source_node_id}.{source_handle_id}; cannot connect it to {node_id}.{target_handle_id}")
                            # If the unfulfilled input is required, abort this node
                            target_input_field = next(
                                (f for f in component_instance.inputs
                                 if f.name == handle_to_node_map.get(node_id, {}).get(target_handle_id)),
                                None,
                            )
                            if target_input_field and target_input_field.required:
                                raise WorkflowExecutionError(
                                    f"Required input '{target_input_field.display_name}' of node "
                                    f"'{component_instance.display_name}' ({node_id}) is not connected or the upstream node produced no output."
                                )

                # c. Validate and merge with the inputs coming from the UI (node.data)
                # validate_inputs is expected to combine inputs_for_run with node_data
                resolved_inputs = component_instance.validate_inputs(inputs_for_run)
                print(f"  Resolved inputs: {resolved_inputs}")

                # d. Run the component
                outputs = await component_instance.run(resolved_inputs, db)
                node_outputs[node_id] = outputs  # Store outputs for downstream nodes
                print(f"  Node outputs: {outputs}")

                # e. Check whether this is the final output (from a ChatOutputNode)
                if node.type == 'chatOutputNode' and 'final_output' in outputs:
                    final_output_value = outputs['final_output']
                    final_output_node_id = node_id

            except (ValueError, WorkflowExecutionError) as e:  # Validation / missing-input errors
                print(f"Input validation failed for node {node_id}: {e}")
                return WorkflowRunResponse(success=False, message=f"Node '{component_instance.display_name}' ({node_id}) error: {e}")
            except Exception as e:
                print(f"Error while executing node {node_id}: {e}")
                traceback.print_exc()
                return WorkflowRunResponse(success=False, message=f"Node '{component_instance.display_name}' ({node_id}) failed: {e}")

        print("\n--- Workflow execution finished ---")

        if final_output_value is not None:
            return WorkflowRunResponse(
                success=True,
                output=str(final_output_value),  # Ensure the output is a string
                output_node_id=final_output_node_id,
                message="Workflow executed successfully",
            )
        else:
            # The workflow finished but no ChatOutputNode produced a 'final_output'
            print("Warning: workflow finished, but no ChatOutputNode was found or it produced no 'final_output'.")
            # Fall back to the last executed node's first output, if any
            last_node_id = execution_order[-1] if execution_order else None
            fallback_output = node_outputs.get(last_node_id, {})
            output_key = next(iter(fallback_output)) if fallback_output else None
            fallback_value = fallback_output.get(output_key) if output_key else "Execution finished, but no explicit output was produced."
            return WorkflowRunResponse(
                success=True,  # Or False, depending on requirements
                output=str(fallback_value),
                output_node_id=last_node_id,
                message="Workflow execution finished, but no designated output node was found.",
            )
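

# --- Usage sketch (illustrative only) ---
# A minimal example of how an API route might call this service. The route path,
# the `WorkflowRunRequest` model, and its `nodes`/`edges` fields are assumptions
# made for illustration; the actual API layer in this project may differ.
#
# from fastapi import APIRouter
# from app.models.pydantic_models import WorkflowRunRequest  # hypothetical request model
#
# router = APIRouter()
# workflow_service = WorkflowService()
#
# @router.post("/workflows/run", response_model=WorkflowRunResponse)
# async def run_workflow(payload: WorkflowRunRequest) -> WorkflowRunResponse:
#     return await workflow_service.execute_workflow(payload.nodes, payload.edges)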