cherry-ai/backend/app/services/workflow_service.py
2025-05-02 17:31:33 +08:00

155 lines
8.2 KiB
Python

# File: backend/app/services/workflow_service.py (Refactor)
# Description: 服务层,使用组件化方式执行工作流
from typing import List, Dict, Optional, Any, Tuple
from app.models.pydantic_models import NodeModel, EdgeModel, WorkflowRunResponse
from app.flow_components.base import component_registry, BaseComponent # Import registry and base
from sqlalchemy.ext.asyncio import AsyncSession
import graphlib # For topological sort
class WorkflowExecutionError(Exception):
    """Raised when a workflow node cannot be executed (e.g. a required input is missing)."""
class WorkflowService:
    """执行工作流的服务 (组件化)

    Executes a workflow described as a graph of nodes and edges: nodes are
    topologically sorted with ``graphlib`` and each node's registered backend
    component is instantiated and run in dependency order, wiring upstream
    outputs into downstream inputs via edge handles.
    """

    async def execute_workflow(
        self,
        nodes: List[NodeModel],
        edges: List[EdgeModel],
        db: Optional[AsyncSession] = None,  # passed through to components that need DB access
    ) -> WorkflowRunResponse:
        """Run the workflow and return its final output.

        Args:
            nodes: Workflow nodes; ``node.type`` selects the backend component
                from ``component_registry`` and ``node.data`` carries UI values.
            edges: Connections; ``edge.target`` consumes ``edge.source``'s
                output identified by ``sourceHandle``/``targetHandle``.
            db: Optional async DB session forwarded to each component's ``run``.

        Returns:
            WorkflowRunResponse with ``output`` taken from a 'chatOutputNode'
            node's ``final_output``, or (as a fallback) the first output of the
            last executed node. ``success=False`` on cycles, build errors, or
            node execution failures.
        """
        print("开始执行工作流 (组件化)...")
        # 1. Build the dependency graph and topologically sort it.
        try:
            # graphlib expects {node: {its dependencies}} — i.e. the TARGET of
            # an edge depends on its SOURCE.
            graph: Dict[str, set[str]] = {node.id: set() for node in nodes}
            node_map: Dict[str, NodeModel] = {node.id: node for node in nodes}
            # {node_id: {handle_id: input_name}} — maps an edge's target handle
            # to the component input field it feeds.
            handle_to_node_map: Dict[str, Dict[str, str]] = {node.id: {} for node in nodes}
            # Pre-populate the handle map from each component's input definitions.
            for node in nodes:
                if node.type in component_registry:
                    component_cls = component_registry[node.type]
                    for input_field in component_cls.inputs:
                        if input_field.is_handle:
                            # By convention the handle ID equals the input name.
                            handle_to_node_map[node.id][input_field.name] = input_field.name
            for edge in edges:
                if edge.source in graph and edge.target in graph:
                    # BUG FIX: only add target-depends-on-source. The previous
                    # code also added graph[edge.source].add(edge.target), which
                    # turned every connected pair into a 2-cycle and made the
                    # sorter raise CycleError for any workflow with edges.
                    graph[edge.target].add(edge.source)
            sorter = graphlib.TopologicalSorter(graph)
            execution_order = list(sorter.static_order())
            print(f"执行顺序: {execution_order}")
        except graphlib.CycleError as e:
            print(f"工作流中存在循环: {e}")
            return WorkflowRunResponse(success=False, message="工作流中检测到循环,无法执行。")
        except Exception as e:
            print(f"构建执行图时出错: {e}")
            return WorkflowRunResponse(success=False, message=f"构建执行图失败: {e}")

        # 2. Execute nodes in topological order.
        node_outputs: Dict[str, Dict[str, Any]] = {}  # {node_id: {output_handle_id: value}}
        final_output_value: Any = None
        final_output_node_id: Optional[str] = None
        for node_id in execution_order:
            node = node_map.get(node_id)
            if not node:
                print(f"错误: 找不到节点 {node_id}")
                continue  # should not happen if the graph is consistent
            component_cls = component_registry.get(node.type)
            if not component_cls:
                print(f"警告: 找不到节点类型 '{node.type}' 的后端组件,跳过节点 {node_id}")
                continue
            print(f"\n--- 执行节点: {node_id} (类型: {node.type}) ---")
            # a. Instantiate the component with the node's UI data.
            component_instance = component_cls(node_data=node.data.model_dump())
            # b. Collect input values from connected upstream nodes.
            inputs_for_run: Dict[str, Any] = {}
            try:
                for edge in edges:
                    if edge.target == node_id:
                        source_node_id = edge.source
                        source_handle_id = edge.sourceHandle
                        target_handle_id = edge.targetHandle
                        if source_node_id in node_outputs and source_handle_id in node_outputs[source_node_id]:
                            # Map the target handle ID to the component's input name.
                            input_name = handle_to_node_map.get(node_id, {}).get(target_handle_id)
                            if input_name:
                                input_value = node_outputs[source_node_id][source_handle_id]
                                inputs_for_run[input_name] = input_value
                                print(f" 输入 '{input_name}' 来自 {source_node_id}.{source_handle_id} = {str(input_value)[:50]}...")
                            else:
                                print(f"警告: 找不到节点 {node_id} 的目标 Handle '{target_handle_id}' 对应的输入字段名。")
                        else:
                            # Upstream node hasn't run or didn't produce the expected output.
                            print(f"警告: 找不到来自 {source_node_id}.{source_handle_id} 的输出,无法连接到 {node_id}.{target_handle_id}")
                            # Fail fast if the missing input is required.
                            target_input_field = next(
                                (
                                    f for f in component_instance.inputs
                                    if f.name == handle_to_node_map.get(node_id, {}).get(target_handle_id)
                                ),
                                None,
                            )
                            if target_input_field and target_input_field.required:
                                raise WorkflowExecutionError(f"节点 '{component_instance.display_name}' ({node_id}) 的必需输入 '{target_input_field.display_name}' 未连接或上游节点未提供输出。")
                # c. Validate and merge edge inputs with UI-supplied values (node.data).
                resolved_inputs = component_instance.validate_inputs(inputs_for_run)
                print(f" 解析后的输入: {resolved_inputs}")
                # d. Run the component (async; may use the DB session).
                outputs = await component_instance.run(resolved_inputs, db)
                node_outputs[node_id] = outputs  # store outputs for downstream nodes
                print(f" 节点输出: {outputs}")
                # e. Capture the workflow's final output from a ChatOutputNode.
                if node.type == 'chatOutputNode' and 'final_output' in outputs:
                    final_output_value = outputs['final_output']
                    final_output_node_id = node_id
            except ValueError as e:  # validation errors from validate_inputs
                print(f"节点 {node_id} 输入验证失败: {e}")
                return WorkflowRunResponse(success=False, message=f"节点 '{component_instance.display_name}' ({node_id}) 执行错误: {e}")
            except Exception as e:
                print(f"节点 {node_id} 执行时发生错误: {e}")
                # NOTE(review): consider logging the full traceback here.
                return WorkflowRunResponse(success=False, message=f"节点 '{component_instance.display_name}' ({node_id}) 执行失败: {e}")

        print("\n--- 工作流执行完成 ---")
        if final_output_value is not None:
            return WorkflowRunResponse(
                success=True,
                output=str(final_output_value),  # ensure output is a string
                output_node_id=final_output_node_id,
                message="工作流执行成功"
            )
        # Workflow finished but no ChatOutputNode produced 'final_output';
        # fall back to the first output of the last executed node.
        print("警告: 工作流执行完毕,但未找到指定的 ChatOutputNode 或其未产生 'final_output'")
        last_node_id = execution_order[-1] if execution_order else None
        fallback_output = node_outputs.get(last_node_id, {})
        output_key = next(iter(fallback_output)) if fallback_output else None
        fallback_value = fallback_output.get(output_key) if output_key else "执行完成,但无明确输出。"
        return WorkflowRunResponse(
            success=True,  # or False depending on requirements
            output=str(fallback_value),
            output_node_id=last_node_id,
            message="工作流执行完成,但未找到指定的输出节点。"
        )