155 lines
8.2 KiB
Python
155 lines
8.2 KiB
Python
# File: backend/app/services/workflow_service.py (Refactor)
|
|
# Description: 服务层,使用组件化方式执行工作流
|
|
|
|
import graphlib  # For topological sort
import traceback
from typing import List, Dict, Optional, Any, Tuple

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.pydantic_models import NodeModel, EdgeModel, WorkflowRunResponse
from app.flow_components.base import component_registry, BaseComponent  # Import registry and base
|
|
|
|
class WorkflowExecutionError(Exception):
    """Raised when a workflow node cannot run, e.g. a required upstream input is missing."""
|
|
|
|
class WorkflowService:
    """Service that executes workflows (component-based).

    Each node's ``type`` is looked up in ``component_registry``; the matching
    component class is instantiated with the node's data and run in
    dependency order.
    """

    async def execute_workflow(
        self,
        nodes: List[NodeModel],
        edges: List[EdgeModel],
        db: Optional[AsyncSession] = None,  # Pass DB if components need it
    ) -> WorkflowRunResponse:
        """Run every node of the workflow in topological (dependency) order.

        Args:
            nodes: Workflow nodes. Nodes whose ``type`` has no registered
                backend component are skipped with a warning.
            edges: Connections from ``source``/``sourceHandle`` to
                ``target``/``targetHandle``. The target node depends on the
                source node.
            db: Optional session forwarded to each component's ``run``.

        Returns:
            A ``WorkflowRunResponse``. It has ``success=False`` when the graph
            contains a cycle, the graph cannot be built, or a node fails.
            Otherwise its output is the ``final_output`` of a
            ``chatOutputNode``. Failing that, it falls back to the first
            output of the last node in execution order.
        """
        print("开始执行工作流 (组件化)...")

        # 1. Build the dependency graph and compute a topological order.
        try:
            execution_order, node_map, handle_to_input = self._plan_execution(nodes, edges)
            print(f"执行顺序: {execution_order}")
        except graphlib.CycleError as e:
            print(f"工作流中存在循环: {e}")
            return WorkflowRunResponse(success=False, message="工作流中检测到循环,无法执行。")
        except Exception as e:
            print(f"构建执行图时出错: {e}")
            return WorkflowRunResponse(success=False, message=f"构建执行图失败: {e}")

        # Index incoming edges by target once, so each node does not rescan the
        # whole edge list. Per-target order matches the original edge order.
        incoming_edges: Dict[str, List[EdgeModel]] = {}
        for edge in edges:
            incoming_edges.setdefault(edge.target, []).append(edge)

        # 2. Execute nodes.
        node_outputs: Dict[str, Dict[str, Any]] = {}  # {node_id: {output_handle_id: value}}
        final_output_value: Any = None
        final_output_node_id: Optional[str] = None

        for node_id in execution_order:
            node = node_map.get(node_id)
            if not node:
                print(f"错误: 找不到节点 {node_id}")
                continue  # Should not happen: the graph only contains known node ids.

            component_cls = component_registry.get(node.type)
            if not component_cls:
                print(f"警告: 找不到节点类型 '{node.type}' 的后端组件,跳过节点 {node_id}")
                continue

            print(f"\n--- 执行节点: {node_id} (类型: {node.type}) ---")

            component_instance: Optional[BaseComponent] = None
            try:
                # a. Instantiate the component with the node's data. This is inside
                #    the try so constructor failures become a failed response.
                component_instance = component_cls(node_data=node.data.model_dump())

                # b. Collect input values from upstream node outputs.
                inputs_for_run = self._gather_inputs(
                    node_id,
                    component_instance,
                    incoming_edges.get(node_id, []),
                    node_outputs,
                    handle_to_input.get(node_id, {}),
                )

                # c. Validate and merge the connected inputs with the UI-provided node data.
                resolved_inputs = component_instance.validate_inputs(inputs_for_run)
                print(f" 解析后的输入: {resolved_inputs}")

                # d. Run the component.
                outputs = await component_instance.run(resolved_inputs, db)
                node_outputs[node_id] = outputs
                print(f" 节点输出: {outputs}")

                # e. Record the final output if it comes from a ChatOutputNode.
                if node.type == 'chatOutputNode' and 'final_output' in outputs:
                    final_output_value = outputs['final_output']
                    final_output_node_id = node_id

            except ValueError as e:  # Validation errors
                print(f"节点 {node_id} 输入验证失败: {e}")
                label = self._node_label(component_instance, node)
                return WorkflowRunResponse(success=False, message=f"节点 '{label}' ({node_id}) 执行错误: {e}")
            except Exception as e:
                print(f"节点 {node_id} 执行时发生错误: {e}")
                traceback.print_exc()
                label = self._node_label(component_instance, node)
                return WorkflowRunResponse(success=False, message=f"节点 '{label}' ({node_id}) 执行失败: {e}")

        print("\n--- 工作流执行完成 ---")
        if final_output_value is not None:
            return WorkflowRunResponse(
                success=True,
                output=str(final_output_value),  # Ensure output is a string
                output_node_id=final_output_node_id,
                message="工作流执行成功",
            )

        # The workflow finished without a ChatOutputNode producing 'final_output'.
        print("警告: 工作流执行完毕,但未找到指定的 ChatOutputNode 或其未产生 'final_output'。")
        return self._fallback_response(execution_order, node_outputs)

    @staticmethod
    def _plan_execution(
        nodes: List[NodeModel],
        edges: List[EdgeModel],
    ) -> Tuple[List[str], Dict[str, NodeModel], Dict[str, Dict[str, str]]]:
        """Build the dependency graph and return (execution_order, node_map, handle_to_input).

        ``handle_to_input`` maps ``{node_id: {handle_id: input_name}}`` for every
        input marked ``is_handle`` on a registered component.

        Raises:
            graphlib.CycleError: if the edges form a cycle (including self-loops).
        """
        # graphlib expects {node: {predecessors}}.
        graph: Dict[str, set[str]] = {node.id: set() for node in nodes}
        node_map: Dict[str, NodeModel] = {node.id: node for node in nodes}
        handle_to_input: Dict[str, Dict[str, str]] = {node.id: {} for node in nodes}

        # Pre-populate the handle map from the component definitions.
        for node in nodes:
            if node.type in component_registry:
                component_cls = component_registry[node.type]
                for input_field in component_cls.inputs:
                    if input_field.is_handle:
                        # The handle id is the input field name.
                        handle_to_input[node.id][input_field.name] = input_field.name

        for edge in edges:
            if edge.source in graph and edge.target in graph:
                # The target depends on the source. Only this direction is added:
                # adding source -> target as well turns every edge into a 2-cycle.
                graph[edge.target].add(edge.source)

        execution_order = list(graphlib.TopologicalSorter(graph).static_order())
        return execution_order, node_map, handle_to_input

    @staticmethod
    def _gather_inputs(
        node_id: str,
        component_instance: BaseComponent,
        incoming: List[EdgeModel],
        node_outputs: Dict[str, Dict[str, Any]],
        handle_to_input: Dict[str, str],
    ) -> Dict[str, Any]:
        """Collect run() inputs for ``node_id`` from the outputs of its upstream nodes.

        Raises:
            WorkflowExecutionError: if an upstream node did not produce the output
                for a connected input that the component marks as required.
        """
        inputs_for_run: Dict[str, Any] = {}
        for edge in incoming:
            source_node_id = edge.source
            source_handle_id = edge.sourceHandle
            target_handle_id = edge.targetHandle
            # Map the target handle id to the component's input name.
            input_name = handle_to_input.get(target_handle_id)

            if source_node_id in node_outputs and source_handle_id in node_outputs[source_node_id]:
                if input_name:
                    input_value = node_outputs[source_node_id][source_handle_id]
                    inputs_for_run[input_name] = input_value
                    print(f" 输入 '{input_name}' 来自 {source_node_id}.{source_handle_id} = {str(input_value)[:50]}...")
                else:
                    print(f"警告: 找不到节点 {node_id} 的目标 Handle '{target_handle_id}' 对应的输入字段名。")
            else:
                # The source node was skipped or did not produce the expected output handle.
                print(f"警告: 找不到来自 {source_node_id}.{source_handle_id} 的输出,无法连接到 {node_id}.{target_handle_id}")
                target_input_field = next((f for f in component_instance.inputs if f.name == input_name), None)
                if target_input_field and target_input_field.required:
                    raise WorkflowExecutionError(f"节点 '{component_instance.display_name}' ({node_id}) 的必需输入 '{target_input_field.display_name}' 未连接或上游节点未提供输出。")
        return inputs_for_run

    @staticmethod
    def _node_label(component_instance: Optional[BaseComponent], node: NodeModel) -> str:
        """Return the display name used in error messages.

        Falls back to the node type if the component was never instantiated.
        """
        if component_instance is not None:
            return component_instance.display_name
        return node.type

    @staticmethod
    def _fallback_response(
        execution_order: List[str],
        node_outputs: Dict[str, Dict[str, Any]],
    ) -> WorkflowRunResponse:
        """Build a success response from the first output of the last node in execution order."""
        last_node_id = execution_order[-1] if execution_order else None
        fallback_output = node_outputs.get(last_node_id, {})
        if fallback_output:
            fallback_value = next(iter(fallback_output.values()))
        else:
            fallback_value = "执行完成,但无明确输出。"
        return WorkflowRunResponse(
            success=True,  # Or False, depending on product requirements
            output=str(fallback_value),
            output_node_id=last_node_id,
            message="工作流执行完成,但未找到指定的输出节点。",
        )
|