add run workflow

This commit is contained in:
adrian 2025-05-02 17:31:33 +08:00
parent 847b3e96c8
commit c9e8b0c123
17 changed files with 1037 additions and 128 deletions

View File

@ -2,7 +2,7 @@
# Description: 聚合 v1 版本的所有 API 路由 # Description: 聚合 v1 版本的所有 API 路由
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router from app.api.v1.endpoints import chat, assistants, sessions, messages, workflow # Import messages router
api_router = APIRouter() api_router = APIRouter()
@ -10,3 +10,4 @@ api_router.include_router(chat.router, prefix="/chat", tags=["Chat"])
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router
api_router.include_router(workflow.router, prefix="/workflow", tags=["Workflow"]) # Add messages router

View File

@ -0,0 +1,22 @@
# File: backend/app/api/v1/endpoints/workflow.py (No significant changes needed)
# Description: 工作流相关的 API 路由 (uses the refactored service)
# ... (保持不变, 确保调用 refactored WorkflowService) ...
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, status
from app.models.pydantic_models import WorkflowRunRequest, WorkflowRunResponse
from app.services.workflow_service import WorkflowService # Import the refactored class
from app.db.database import get_db_session # Import if service needs DB
from sqlalchemy.ext.asyncio import AsyncSession
router = APIRouter()
workflow_service = WorkflowService()
@router.post("/run", response_model=WorkflowRunResponse)
async def run_workflow(
request: WorkflowRunRequest,
db: Optional[AsyncSession] = Depends(get_db_session) # Make DB optional or required based on component needs
):
print(f"收到运行工作流请求: {len(request.nodes)} 个节点, {len(request.edges)} 条边")
result = await workflow_service.execute_workflow(request.nodes, request.edges, db)
# No need to raise HTTPException here if service returns success=False
return result

View File

@ -0,0 +1,86 @@
# File: backend/app/flow_components/base.py (New)
# Description: Base class for all workflow node components
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Optional, Type, ClassVar
from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession # Import if components need DB access
# --- Input/Output Field Definitions ---
# These classes help define the expected inputs and outputs for handles/UI fields
class InputField(BaseModel):
name: str # Internal name/key, matches handleId or data key
display_name: str # User-friendly name for UI
info: Optional[str] = None # Tooltip/description
field_type: str # e.g., 'str', 'int', 'float', 'bool', 'dict', 'code', 'prompt', 'llm', 'message'
required: bool = True
value: Any = None # Default value for UI fields
is_handle: bool = False # True if this input comes from a Handle connection
# Add more metadata as needed (e.g., options for dropdowns, range for sliders)
options: Optional[List[str]] = None
range_spec: Optional[Dict[str, float]] = None # e.g., {'min': 0, 'max': 1, 'step': 0.1}
class OutputField(BaseModel):
name: str # Internal name/key, matches handleId
display_name: str
field_type: str # Data type of the output handle
info: Optional[str] = None
# --- Base Component Class ---
class BaseComponent(ABC, BaseModel):
# Class variables for metadata (can be overridden by subclasses)
display_name: ClassVar[str] = "Base Component"
description: ClassVar[str] = "A base component for workflow nodes."
icon: ClassVar[Optional[str]] = None # Icon name (e.g., from Lucide)
name: ClassVar[str] # Unique internal name/type identifier (matches frontend node type)
# Instance variable to hold node-specific data from frontend
node_data: Dict[str, Any] = {}
# Class variables defining inputs and outputs
inputs: ClassVar[List[InputField]] = []
outputs: ClassVar[List[OutputField]] = []
class Config:
arbitrary_types_allowed = True # Allow complex types like AsyncSession if needed
@abstractmethod
async def run(
self,
inputs: Dict[str, Any], # Resolved inputs from UI data and connected nodes
db: Optional[AsyncSession] = None # Pass DB session if needed
) -> Dict[str, Any]: # Return dictionary of output values keyed by output name
"""Executes the component's logic."""
pass
def validate_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Validates if required inputs are present."""
resolved_inputs = {}
missing = []
for field in self.inputs:
# Combine node_data (UI config) and resolved inputs from connections
value = inputs.get(field.name, self.node_data.get(field.name, field.value))
if field.required and value is None:
missing.append(field.display_name)
# TODO: Add type validation based on field.field_type
resolved_inputs[field.name] = value
if missing:
raise ValueError(f"节点 '{self.display_name}' 缺少必需的输入: {', '.join(missing)}")
return resolved_inputs
# --- Component Registry ---
# Simple dictionary to map node type strings to component classes
component_registry: Dict[str, Type[BaseComponent]] = {}
def register_component(cls: Type[BaseComponent]):
"""Decorator to register component classes."""
if not hasattr(cls, 'name') or not cls.name:
raise ValueError(f"Component class {cls.__name__} must have a 'name' attribute.")
if cls.name in component_registry:
print(f"警告: 组件 '{cls.name}' 被重复注册。")
component_registry[cls.name] = cls
print(f"已注册组件: {cls.name}")
return cls

View File

@ -0,0 +1,27 @@
# File: backend/app/flow_components/chat_input.py (New)
# Description: Backend component for ChatInputNode
from .base import BaseComponent, InputField, OutputField, register_component
from typing import ClassVar, Dict, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
@register_component
class ChatInputNodeComponent(BaseComponent):
name: ClassVar[str] = "chatInputNode" # Matches frontend type
display_name: ClassVar[str] = "Chat Input"
description: ClassVar[str] = "从 Playground 获取聊天输入。"
icon: ClassVar[str] = "MessageCircleQuestion"
inputs: ClassVar[List[InputField]] = [
InputField(name="text", display_name="Text", field_type="str", required=False, is_handle=False, info="用户输入的文本或默认文本。"),
]
outputs: ClassVar[List[OutputField]] = [
OutputField(name="message-output", display_name="Message", field_type="message", info="输出的聊天消息。") # Use 'message' type
]
async def run(self, inputs: Dict[str, Any], db: Optional[AsyncSession] = None) -> Dict[str, Any]:
# Inputs are already validated and resolved by the base class/executor
text_input = inputs.get("text", self.node_data.get("text", "")) # Get text from UI data or default
print(f"ChatInputNode ({self.node_data.get('id', 'N/A')}): 输出文本 '{text_input}'")
# Output format should match the defined output field name
return {"message-output": text_input} # Output the text

View File

@ -0,0 +1,38 @@
# File: backend/app/flow_components/chat_output.py (New)
# Description: Backend component for ChatOutputNode
from .base import BaseComponent, InputField, OutputField, register_component
from typing import ClassVar, Dict, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
@register_component
class ChatOutputNodeComponent(BaseComponent):
name: ClassVar[str] = "chatOutputNode"
display_name: ClassVar[str] = "Chat Output"
description: ClassVar[str] = "在 Playground 显示聊天消息。"
icon: ClassVar[str] = "MessageCircleReply"
inputs: ClassVar[List[InputField]] = [
InputField(name="message-input", display_name="Message", field_type="message", required=True, is_handle=True, info="连接要显示的消息。"),
InputField(name="displayText", display_name="Text", field_type="str", required=False, is_handle=False, info="(可选)覆盖显示的文本。"),
]
outputs: ClassVar[List[OutputField]] = [
# This node typically doesn't output further, but could pass through
# OutputField(name="output", display_name="Output", field_type="message")
]
async def run(self, inputs: Dict[str, Any], db: Optional[AsyncSession] = None) -> Dict[str, Any]:
message_input = inputs.get("message-input")
display_override = inputs.get("displayText", self.node_data.get("displayText")) # Check UI data too
if message_input is None:
raise ValueError("ChatOutputNode 未收到输入消息。")
# Determine what to "output" (in this context, what the workflow considers the result)
final_text = display_override if display_override else str(message_input)
print(f"ChatOutputNode ({self.node_data.get('id', 'N/A')}): 最终显示 '{final_text[:50]}...'")
# Since this is often a terminal node for execution, return the processed input
# The executor will decide how to handle this final output
return {"final_output": final_text} # Use a consistent key like 'final_output'

View File

@ -0,0 +1,77 @@
# File: backend/app/flow_components/llm_node.py (New)
# Description: Backend component for LLMNode
from .base import BaseComponent, InputField, OutputField, register_component
from typing import ClassVar, Dict, Any, Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.chat_service import ChatService # Assuming ChatService can be used or adapted
from app.core.config import OPENAI_API_KEY
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
# Instantiate or get ChatService instance
# This might need better dependency injection in a real app
chat_service_instance = ChatService(default_api_key=OPENAI_API_KEY)
@register_component
class LLMNodeComponent(BaseComponent):
name: ClassVar[str] = "llmNode"
display_name: ClassVar[str] = "LLM 调用"
description: ClassVar[str] = "使用大语言模型生成文本。"
icon: ClassVar[str] = "BrainCircuit"
inputs: ClassVar[List[InputField]] = [
InputField(name="input-text", display_name="输入", field_type="message", required=True, is_handle=True, info="连接输入的文本或消息。"),
InputField(name="systemPrompt", display_name="系统提示", field_type="str", required=True, is_handle=False, info="定义助手的角色和行为。"),
InputField(name="model", display_name="模型名称", field_type="str", required=True, is_handle=False, info="要使用的 LLM 模型。"), # Add options if needed
InputField(name="temperature", display_name="温度", field_type="float", required=True, is_handle=False, value=0.7, range_spec={'min': 0, 'max': 1, 'step': 0.1}),
InputField(name="apiKey", display_name="API Key", field_type="secret", required=False, is_handle=False, info="(不安全)覆盖默认 API Key。"),
# Add other parameters like max_tokens, etc. as InputFields
]
outputs: ClassVar[List[OutputField]] = [
OutputField(name="output-message", display_name="Message", field_type="message", info="LLM 生成的消息。")
]
async def run(self, inputs: Dict[str, Any], db: Optional[AsyncSession] = None) -> Dict[str, Any]:
prompt_input = inputs.get("input-text")
system_prompt = inputs.get("systemPrompt")
model = inputs.get("model")
temperature = inputs.get("temperature")
# api_key = inputs.get("apiKey") # Handle API key securely if used
if not prompt_input or not system_prompt or not model or temperature is None:
raise ValueError("LLMNode 配置或输入不完整。")
print(f"LLMNode ({self.node_data.get('id', 'N/A')}): 运行模型 '{model}' (Temp: {temperature})")
print(f" System Prompt: {system_prompt[:50]}...")
print(f" Input Prompt: {prompt_input[:50]}...")
# --- Adapt ChatService or LangChain call ---
# This simplified call assumes a method that takes direct inputs
# In reality, you might build a small LangChain chain here
try:
# Construct messages for a more robust call
messages: List[BaseMessage] = []
if system_prompt:
messages.append(SystemMessage(content=system_prompt))
# Assume input-text provides the user message content
messages.append(HumanMessage(content=str(prompt_input))) # Ensure it's a string
# Simplified call - needs adaptation based on ChatService structure
# Maybe ChatService needs a method like:
# async def invoke_llm(self, messages: List[BaseMessage], model_name: str, temperature: float, ...) -> str:
# result = await chat_service_instance.invoke_llm(messages, model, temperature)
# --- Temporary Simulation ---
await asyncio.sleep(1)
result = f"AI回复(模拟): 处理了 '{str(prompt_input)[:20]}...'"
# --- End Simulation ---
print(f"LLMNode Output: {result[:50]}...")
return {"output-message": result}
except Exception as e:
print(f"LLMNode 执行失败: {e}")
raise # Re-raise the exception for the executor to handle
# Need asyncio for simulation
import asyncio

View File

@ -7,6 +7,7 @@ from app.api.v1.api import api_router as api_router_v1
import app.core.config # Ensure config is loaded import app.core.config # Ensure config is loaded
from app.db.database import create_db_and_tables # Import table creation function from app.db.database import create_db_and_tables # Import table creation function
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from app.flow_components import base, chat_input, llm_node, chat_output # Add other component modules here
# --- Lifespan context manager for startup/shutdown events --- # --- Lifespan context manager for startup/shutdown events ---
@asynccontextmanager @asynccontextmanager
@ -17,6 +18,7 @@ async def lifespan(app: FastAPI):
print("数据库表已检查/创建。") print("数据库表已检查/创建。")
# You can add the default assistant creation here if needed, # You can add the default assistant creation here if needed,
# but doing it in the service/model definition might be simpler for defaults. # but doing it in the service/model definition might be simpler for defaults.
print(f"已注册组件: {list(base.component_registry.keys())}")
yield yield
# Shutdown actions # Shutdown actions
print("应用程序关闭中...") print("应用程序关闭中...")

View File

@ -2,7 +2,7 @@
# Description: Pydantic 模型定义 API 数据结构 # Description: Pydantic 模型定义 API 数据结构
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional, List from typing import Dict, Optional, List
import uuid import uuid
from datetime import datetime # Use datetime directly from datetime import datetime # Use datetime directly
@ -90,4 +90,43 @@ class MessageRead(MessageBase):
created_at: datetime created_at: datetime
class Config: class Config:
from_attributes = True from_attributes = True
# --- Workflow Node/Edge Models (for API request/response) ---
# Mirrors React Flow structure loosely
class NodeData(BaseModel):
# Define common fields or use Dict[str, Any]
label: Optional[str] = None
text: Optional[str] = None # For ChatInput
displayText: Optional[str] = None # For ChatOutput
model: Optional[str] = None # For LLMNode
temperature: Optional[float] = None # For LLMNode
systemPrompt: Optional[str] = None # For LLMNode
# Add other potential data fields from your nodes
# Use Extra.allow for flexibility if needed:
# class Config:
# extra = 'allow'
class NodeModel(BaseModel):
id: str
type: str # e.g., 'chatInputNode', 'llmNode'
position: Dict[str, float] # { x: number, y: number }
data: NodeData # Use the specific data model
class EdgeModel(BaseModel):
id: str
source: str
target: str
sourceHandle: Optional[str] = None
targetHandle: Optional[str] = None
# --- Workflow Execution Models ---
class WorkflowRunRequest(BaseModel):
nodes: List[NodeModel]
edges: List[EdgeModel]
class WorkflowRunResponse(BaseModel):
success: bool
message: Optional[str] = None
output: Optional[str] = None # The final output text
output_node_id: Optional[str] = None # ID of the node that produced the output

View File

@ -0,0 +1,154 @@
# File: backend/app/services/workflow_service.py (Refactor)
# Description: 服务层,使用组件化方式执行工作流
from typing import List, Dict, Optional, Any, Tuple
from app.models.pydantic_models import NodeModel, EdgeModel, WorkflowRunResponse
from app.flow_components.base import component_registry, BaseComponent # Import registry and base
from sqlalchemy.ext.asyncio import AsyncSession
import graphlib # For topological sort
class WorkflowExecutionError(Exception): pass
class WorkflowService:
"""执行工作流的服务 (组件化)"""
async def execute_workflow(
self,
nodes: List[NodeModel],
edges: List[EdgeModel],
db: Optional[AsyncSession] = None # Pass DB if components need it
) -> WorkflowRunResponse:
print("开始执行工作流 (组件化)...")
# 1. 构建依赖图 & 拓扑排序
try:
graph: Dict[str, set[str]] = {node.id: set() for node in nodes}
node_map: Dict[str, NodeModel] = {node.id: node for node in nodes}
handle_to_node_map: Dict[str, Dict[str, str]] = {node.id: {} for node in nodes} # {node_id: {handle_id: input_name}}
# Pre-populate handle map based on component definitions
for node in nodes:
if node.type in component_registry:
component_cls = component_registry[node.type]
for input_field in component_cls.inputs:
if input_field.is_handle:
handle_to_node_map[node.id][input_field.name] = input_field.name # Map handle ID to input name
for edge in edges:
if edge.source in graph and edge.target in graph:
graph[edge.source].add(edge.target) # source depends on target? No, target depends on source
# Let's reverse: target depends on source
# graph[edge.target].add(edge.source) # This seems wrong for graphlib
# graphlib expects {node: {dependencies}}
# So, target node depends on source node
graph[edge.target].add(edge.source)
sorter = graphlib.TopologicalSorter(graph)
execution_order = list(sorter.static_order())
print(f"执行顺序: {execution_order}")
except graphlib.CycleError as e:
print(f"工作流中存在循环: {e}")
return WorkflowRunResponse(success=False, message="工作流中检测到循环,无法执行。")
except Exception as e:
print(f"构建执行图时出错: {e}")
return WorkflowRunResponse(success=False, message=f"构建执行图失败: {e}")
# 2. 执行节点
node_outputs: Dict[str, Dict[str, Any]] = {} # Store outputs {node_id: {output_handle_id: value}}
final_output_value: Any = None
final_output_node_id: Optional[str] = None
for node_id in execution_order:
node = node_map.get(node_id)
if not node:
print(f"错误: 找不到节点 {node_id}")
continue # Should not happen if graph is correct
component_cls = component_registry.get(node.type)
if not component_cls:
print(f"警告: 找不到节点类型 '{node.type}' 的后端组件,跳过节点 {node_id}")
continue
print(f"\n--- 执行节点: {node_id} (类型: {node.type}) ---")
# a. 实例化组件,传入节点数据
component_instance = component_cls(node_data=node.data.model_dump())
# b. 收集输入值
inputs_for_run: Dict[str, Any] = {}
try:
# Gather inputs from connected parent nodes
for edge in edges:
if edge.target == node_id:
source_node_id = edge.source
source_handle_id = edge.sourceHandle
target_handle_id = edge.targetHandle
if source_node_id in node_outputs and source_handle_id in node_outputs[source_node_id]:
# Map target handle ID to the correct input name for the component's run method
input_name = handle_to_node_map.get(node_id, {}).get(target_handle_id)
if input_name:
input_value = node_outputs[source_node_id][source_handle_id]
inputs_for_run[input_name] = input_value
print(f" 输入 '{input_name}' 来自 {source_node_id}.{source_handle_id} = {str(input_value)[:50]}...")
else:
print(f"警告: 找不到节点 {node_id} 的目标 Handle '{target_handle_id}' 对应的输入字段名。")
else:
# This might happen if the source node hasn't run or didn't produce the expected output
print(f"警告: 找不到来自 {source_node_id}.{source_handle_id} 的输出,无法连接到 {node_id}.{target_handle_id}")
# Check if the input is required
target_input_field = next((f for f in component_instance.inputs if f.name == handle_to_node_map.get(node_id, {}).get(target_handle_id)), None)
if target_input_field and target_input_field.required:
raise WorkflowExecutionError(f"节点 '{component_instance.display_name}' ({node_id}) 的必需输入 '{target_input_field.display_name}' 未连接或上游节点未提供输出。")
# c. 验证并合并来自 UI 的输入 (node.data)
# validate_inputs should combine inputs_for_run and node_data
resolved_inputs = component_instance.validate_inputs(inputs_for_run)
print(f" 解析后的输入: {resolved_inputs}")
# d. 执行组件的 run 方法
outputs = await component_instance.run(resolved_inputs, db)
node_outputs[node_id] = outputs # Store outputs
print(f" 节点输出: {outputs}")
# e. 检查是否为最终输出 (来自 ChatOutputNode)
if node.type == 'chatOutputNode' and 'final_output' in outputs:
final_output_value = outputs['final_output']
final_output_node_id = node_id
except ValueError as e: # Catch validation errors
print(f"节点 {node_id} 输入验证失败: {e}")
return WorkflowRunResponse(success=False, message=f"节点 '{component_instance.display_name}' ({node_id}) 执行错误: {e}")
except Exception as e:
print(f"节点 {node_id} 执行时发生错误: {e}")
# Log traceback e
return WorkflowRunResponse(success=False, message=f"节点 '{component_instance.display_name}' ({node_id}) 执行失败: {e}")
print("\n--- 工作流执行完成 ---")
if final_output_value is not None:
return WorkflowRunResponse(
success=True,
output=str(final_output_value), # Ensure output is string
output_node_id=final_output_node_id,
message="工作流执行成功"
)
else:
# Workflow finished but didn't produce output via a ChatOutputNode
print("警告: 工作流执行完毕,但未找到指定的 ChatOutputNode 或其未产生 'final_output'")
# Find the last node's output as a fallback?
last_node_id = execution_order[-1] if execution_order else None
fallback_output = node_outputs.get(last_node_id, {})
output_key = next(iter(fallback_output)) if fallback_output else None
fallback_value = fallback_output.get(output_key) if output_key else "执行完成,但无明确输出。"
return WorkflowRunResponse(
success=True, # Or False depending on requirements
output=str(fallback_value),
output_node_id=last_node_id,
message="工作流执行完成,但未找到指定的输出节点。"
)

View File

@ -0,0 +1,69 @@
// File: frontend/app/workflow/components/ChatInputNode.tsx (新建)
// Description: 自定义 ChatInput 节点
import React, { memo, useCallback, ChangeEvent } from 'react';
import { Handle, Position, useReactFlow, Node } from 'reactflow';
import { MessageCircleQuestion } from 'lucide-react'; // 使用合适的图标
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import type { ChatInputNodeData, CustomNodeProps } from './types';
const ChatInputNodeComponent = ({ id, data, isConnectable }: CustomNodeProps<ChatInputNodeData>) => {
const { setNodes } = useReactFlow<ChatInputNodeData>();
const handleTextChange = useCallback((event: ChangeEvent<HTMLTextAreaElement>) => {
const newText = event.target.value;
setNodes((nds: Node<ChatInputNodeData>[]) =>
nds.map((node) => {
if (node.id === id) {
return { ...node, data: { ...node.data, text: newText } };
}
return node;
})
);
}, [id, setNodes]);
return (
<div className="react-flow__node-genericNode nopan bg-gray-50 dark:bg-gray-800 border border-gray-300 dark:border-gray-700 rounded-lg shadow-lg w-96"> {/* 调整宽度 */}
{/* 节点头部 */}
<div className="bg-gray-100 dark:bg-gray-700 p-3 border-b border-gray-300 dark:border-gray-600">
<div className="flex items-center gap-2">
<MessageCircleQuestion size={18} className="text-gray-700 dark:text-gray-300" />
<strong className="text-gray-800 dark:text-gray-200">Chat Input</strong>
</div>
<p className="text-xs text-gray-600 dark:text-gray-400 mt-1"> Playground </p>
</div>
{/* 节点内容 */}
<div className="p-4 space-y-2">
<div className="nodrag">
<Label htmlFor={`chat-input-text-${id}`} className="text-xs font-semibold text-gray-500 dark:text-gray-400">Text</Label>
<Textarea
id={`chat-input-text-${id}`}
name="text"
value={data.text || ''}
onChange={handleTextChange}
placeholder="输入聊天内容..."
className="mt-1 text-sm min-h-[80px] bg-white dark:bg-gray-700"
rows={3}
/>
</div>
</div>
{/* 输出 Handle */}
<div className="relative p-2 border-t border-gray-300 dark:border-gray-600">
<Label className="text-xs font-semibold text-gray-500 dark:text-gray-400 block text-right pr-7">Message</Label>
<Handle
type="source"
position={Position.Right}
id="message-output"
isConnectable={isConnectable}
className="!w-3 !h-3 !bg-purple-500 top-1/2" // 使用紫色 Handle
/>
</div>
</div>
);
};
ChatInputNodeComponent.displayName = 'ChatInputNode';
export const ChatInputNode = memo(ChatInputNodeComponent);

View File

@ -0,0 +1,84 @@
// File: frontend/app/workflow/components/ChatOutputNode.tsx (新建)
// Description: 自定义 ChatOutput 节点
import React, { memo, useCallback, ChangeEvent } from 'react';
import { Handle, Position, useReactFlow, Node } from 'reactflow';
import { MessageCircleReply } from 'lucide-react'; // 使用合适的图标
import { Label } from "@/components/ui/label";
import { Input } from "@/components/ui/input"; // 使用 Input 或 Textarea 根据需要
import type { ChatOutputNodeData, CustomNodeProps } from './types';
const ChatOutputNodeComponent = ({ id, data, isConnectable }: CustomNodeProps<ChatOutputNodeData>) => {
const { setNodes } = useReactFlow<ChatOutputNodeData>();
// 处理文本变化 (如果需要可编辑)
const handleTextChange = useCallback((event: ChangeEvent<HTMLInputElement>) => {
const newText = event.target.value;
setNodes((nds: Node<ChatOutputNodeData>[]) =>
nds.map((node) => {
if (node.id === id) {
return { ...node, data: { ...node.data, displayText: newText } };
}
return node;
})
);
}, [id, setNodes]);
return (
<div className="react-flow__node-genericNode bg-gray-50 dark:bg-gray-800 border border-gray-300 dark:border-gray-700 rounded-lg shadow-lg w-96">
{/* 输入 Handle */}
<Handle
type="target"
position={Position.Left}
id="message-input"
isConnectable={isConnectable}
className="!w-3 !h-3 !bg-purple-500 top-1/2" // 使用紫色 Handle
/>
{/* 节点头部 */}
<div className="bg-gray-100 dark:bg-gray-700 p-3 border-b border-gray-300 dark:border-gray-600">
<div className="flex items-center gap-2">
<MessageCircleReply size={18} className="text-gray-700 dark:text-gray-300" />
<strong className="text-gray-800 dark:text-gray-200">Chat Output</strong>
</div>
<p className="text-xs text-gray-600 dark:text-gray-400 mt-1"> Playground </p>
</div>
{/* 节点内容 */}
<div className="p-4 space-y-2">
<div className="nodrag">
<Label htmlFor={`chat-output-text-${id}`} className="text-xs font-semibold text-gray-500 dark:text-gray-400">Text</Label>
{/* 根据截图,这里像是一个 Input可能用于显示模板或结果 */}
<Input
id={`chat-output-text-${id}`}
name="displayText"
value={data.displayText || ''} // 显示传入的数据
onChange={handleTextChange} // 如果需要可编辑
placeholder="等待输入..."
className="mt-1 text-sm h-10 bg-white dark:bg-gray-700"
// readOnly // 如果只是显示,可以设为只读
/>
{/* 或者只是一个简单的文本显示区域 */}
{/* <p className="mt-1 text-sm p-2 border rounded bg-white dark:bg-gray-700 min-h-[40px]">
{data.displayText || '等待输入...'}
</p> */}
</div>
</div>
{/* 输出 Handle (可选,根据是否需要继续传递) */}
{/* <div className="relative p-2 border-t border-gray-300 dark:border-gray-600">
<Label className="text-xs font-semibold text-gray-500 dark:text-gray-400 block text-right pr-7">Message</Label>
<Handle
type="source"
position={Position.Right}
id="message-passthrough"
isConnectable={isConnectable}
className="w-3 h-3 !bg-purple-500 top-1/2"
/>
</div> */}
</div>
);
};
ChatOutputNodeComponent.displayName = 'ChatOutputNode';
export const ChatOutputNode = memo(ChatOutputNodeComponent);

View File

@ -1,104 +1,231 @@
// File: frontend/app/workflow/components/LLMNode.tsx // File: frontend/app/workflow/components/LLMNode.tsx
// Description: 自定义 LLM 节点组件 // Description: 自定义 LLM 节点组件
import React, { useState, useCallback, useEffect, memo, ChangeEvent } from 'react'; import React, {
import { Handle, Position, useUpdateNodeInternals, useReactFlow, Node } from 'reactflow'; useState,
import { BrainCircuit } from 'lucide-react'; useCallback,
useEffect,
memo,
ChangeEvent,
} from "react";
import {
Handle,
Position,
useUpdateNodeInternals,
useReactFlow,
Node,
} from "reactflow";
import { BrainCircuit } from "lucide-react";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { Textarea } from "@/components/ui/textarea"; import { Textarea } from "@/components/ui/textarea";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Slider } from "@/components/ui/slider"; import { Slider } from "@/components/ui/slider";
import { Label } from "@/components/ui/label"; import { Label } from "@/components/ui/label";
import { type LLMNodeData, type CustomNodeProps, availableModels } from './types'; // 导入类型和模型列表 import {
type LLMNodeData,
type CustomNodeProps,
availableModels,
} from "./types"; // 导入类型和模型列表
const LLMNodeComponent = ({ id, data, isConnectable }: CustomNodeProps<LLMNodeData>) => { const LLMNodeComponent = ({
// const [currentData, setCurrentData] = useState<LLMNodeData>(data); id,
data,
isConnectable,
}: CustomNodeProps<LLMNodeData>) => {
// const [currentData, setCurrentData] = useState<LLMNodeData>(data);
const updateNodeInternals = useUpdateNodeInternals(); const updateNodeInternals = useUpdateNodeInternals();
const { setNodes } = useReactFlow<LLMNodeData>(); // 获取 setNodes 方法 const { setNodes } = useReactFlow<LLMNodeData>(); // 获取 setNodes 方法
// 从 props 更新内部状态 (如果外部数据变化)
// useEffect(() => {
// setCurrentData(data);
// }, [data]);
// 处理内部表单变化的通用回调 - 直接更新 React Flow 主状态 // 处理内部表单变化的通用回调 - 直接更新 React Flow 主状态
const handleDataChange = useCallback((key: keyof LLMNodeData, value: any) => { const handleDataChange = useCallback(
// 使用 setNodes 更新特定节点的数据 (key: keyof LLMNodeData, value: any) => {
setNodes((nds: Node<LLMNodeData>[]) => // 显式指定类型 // 使用 setNodes 更新特定节点的数据
nds.map((node) => { setNodes(
if (node.id === id) { (
// 创建一个新的 data 对象 nds: Node<LLMNodeData>[] // 显式指定类型
const updatedData = { ...node.data, [key]: value }; ) =>
return { ...node, data: updatedData }; nds.map((node) => {
} if (node.id === id) {
return node; // 创建一个新的 data 对象
}) const updatedData = { ...node.data, [key]: value };
); return { ...node, data: updatedData };
console.log(`(LLMNode) Node ${id} data updated:`, { [key]: value }); }
}, [id, setNodes]); // 依赖 id 和 setNodes return node;
})
);
console.log(`(LLMNode) Node ${id} data updated:`, { [key]: value });
},
[id, setNodes]
); // 依赖 id 和 setNodes
const handleSliderChange = useCallback((value: number[]) => { handleDataChange('temperature', value[0]); }, [handleDataChange]); const handleSliderChange = useCallback(
const handleSelectChange = useCallback((value: string) => { handleDataChange('model', value); }, [handleDataChange]); (value: number[]) => {
const handleTextChange = useCallback((event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => { handleDataChange("temperature", value[0]);
},
[handleDataChange]
);
const handleSelectChange = useCallback(
(value: string) => {
handleDataChange("model", value);
},
[handleDataChange]
);
const handleTextChange = useCallback(
(event: ChangeEvent<HTMLTextAreaElement | HTMLInputElement>) => {
const { name, value } = event.target; const { name, value } = event.target;
handleDataChange(name as keyof LLMNodeData, value); handleDataChange(name as keyof LLMNodeData, value);
}, [handleDataChange]); },
[handleDataChange]
);
// 检查 inputConnected 状态是否需要更新 (如果外部更新了) // 检查 inputConnected 状态是否需要更新 (如果外部更新了)
// 注意:这个逻辑依赖于父组件正确更新了 data.inputConnected // 注意:这个逻辑依赖于父组件正确更新了 data.inputConnected
useEffect(() => { useEffect(() => {
// 可以在这里添加逻辑,如果 props.data.inputConnected 和 UI 显示不一致时触发更新 // 可以在这里添加逻辑,如果 props.data.inputConnected 和 UI 显示不一致时触发更新
// 但通常连接状态由 onConnect/onEdgesChange 在父组件处理更佳 // 但通常连接状态由 onConnect/onEdgesChange 在父组件处理更佳
// updateNodeInternals(id); // 如果 Handle 显示依赖 inputConnected可能需要调用 // updateNodeInternals(id); // 如果 Handle 显示依赖 inputConnected可能需要调用
}, [data.inputConnected, id, updateNodeInternals]); }, [data.inputConnected, id, updateNodeInternals]);
return ( return (
// 调整宽度,例如 w-96 (24rem) // 调整宽度,例如 w-96 (24rem)
<div className="react-flow__node react-flow__node-genericNode nopan selected selectable draggable bg-purple-50 dark:bg-gray-800 border border-purple-200 dark:border-gray-700 rounded-lg shadow-lg w-96 overflow-hidden"> <div className="react-flow__node-genericNode nopan bg-purple-50 dark:bg-gray-800 border border-purple-200 dark:border-gray-700 rounded-lg shadow-lg w-96">
<div className="bg-purple-100 dark:bg-gray-700 p-3 border-b border-purple-200 dark:border-gray-600"> <div className="bg-purple-100 dark:bg-gray-700 p-3 border-b border-purple-200 dark:border-gray-600">
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<BrainCircuit size={18} className="text-purple-700 dark:text-purple-300" /> <BrainCircuit
<strong className="text-purple-800 dark:text-purple-200">LLM </strong> size={18}
className="text-purple-700 dark:text-purple-300"
/>
<strong className="text-purple-800 dark:text-purple-200">
LLM
</strong>
</div> </div>
<p className="text-xs text-purple-600 dark:text-purple-400 mt-1">使</p> <p className="text-xs text-purple-600 dark:text-purple-400 mt-1">
使
</p>
</div> </div>
<div className="p-4 space-y-3"> <div className="p-4 space-y-3">
<div className="relative nodrag"> <div className="relative nodrag">
<Label className="text-xs font-semibold text-gray-500 dark:text-gray-400"></Label> <Handle
<Handle type="target" position={Position.Left} id="input-text" isConnectable={isConnectable} className="w-3 h-3 !bg-blue-500 top-1/2" /> type="target"
{/* 直接使用 props.data.inputConnected */} position={Position.Left}
{!data.inputConnected && <p className="text-xs text-gray-400 italic mt-1"></p>} id="input-text"
isConnectable={isConnectable}
className="!w-3 !h-3 !bg-blue-500 top-1/2 z-50 !-left-5"
/>
<Label className="text-xs font-semibold text-gray-500 dark:text-gray-400">
</Label>
{/* 直接使用 props.data.inputConnected */}
{!data.inputConnected && (
<p className="text-xs text-gray-400 italic mt-1"></p>
)}
</div>
<div className="nodrag">
<Label
htmlFor={`systemPrompt-${id}`}
className="text-xs font-semibold"
>
</Label>
{/* 直接使用 props.data 和 onChange 回调 */}
<Textarea
id={`systemPrompt-${id}`}
name="systemPrompt"
value={data.systemPrompt}
onChange={handleTextChange}
placeholder="例如:你是一个乐于助人的助手。"
className="mt-1 text-xs min-h-[60px] bg-white dark:bg-gray-700"
rows={3}
/>
</div>
<div className="nodrag">
<Label htmlFor={`model-${id}`} className="text-xs font-semibold">
</Label>
<Select
name="model"
value={data.model}
onValueChange={handleSelectChange}
>
<SelectTrigger
id={`model-${id}`}
className="mt-1 h-8 text-xs bg-white dark:bg-gray-700"
>
<SelectValue placeholder="选择模型" />
</SelectTrigger>
<SelectContent>
{availableModels.map((model) => (
<SelectItem
key={model.value}
value={model.value}
className="text-xs"
>
{model.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="nodrag">
<Label htmlFor={`apiKey-${id}`} className="text-xs font-semibold">
API Key ()
</Label>
<Input
id={`apiKey-${id}`}
name="apiKey"
type="password"
value={data.apiKey || ""}
onChange={handleTextChange}
placeholder="sk-..."
className="mt-1 h-8 text-xs bg-white dark:bg-gray-700"
/>
<p className="text-xs text-red-500 mt-1">
</p>
</div>
<div className="nodrag">
<Label
htmlFor={`temperature-${id}`}
className="text-xs font-semibold"
>
: {data.temperature?.toFixed(1) ?? "N/A"}
</Label>
<Slider
id={`temperature-${id}`}
min={0}
max={1}
step={0.1}
value={[data.temperature ?? 0.7]}
onValueChange={handleSliderChange}
className="mt-2"
/>
<div className="flex justify-between text-xs text-gray-500 dark:text-gray-400 mt-1">
<span></span>
<span></span>
</div>
</div> </div>
<div className="nodrag">
<Label htmlFor={`systemPrompt-${id}`} className="text-xs font-semibold"></Label>
{/* 直接使用 props.data 和 onChange 回调 */}
<Textarea id={`systemPrompt-${id}`} name="systemPrompt" value={data.systemPrompt} onChange={handleTextChange} placeholder="例如:你是一个乐于助人的助手。" className="mt-1 text-xs min-h-[60px] bg-white dark:bg-gray-700" rows={3} />
</div>
<div className="nodrag">
<Label htmlFor={`model-${id}`} className="text-xs font-semibold"></Label>
<Select name="model" value={data.model} onValueChange={handleSelectChange}>
<SelectTrigger id={`model-${id}`} className="mt-1 h-8 text-xs bg-white dark:bg-gray-700"><SelectValue placeholder="选择模型" /></SelectTrigger>
<SelectContent>{availableModels.map(model => (<SelectItem key={model.value} value={model.value} className="text-xs">{model.label}</SelectItem>))}</SelectContent>
</Select>
</div>
<div className="nodrag">
<Label htmlFor={`apiKey-${id}`} className="text-xs font-semibold">API Key ()</Label>
<Input id={`apiKey-${id}`} name="apiKey" type="password" value={data.apiKey || ''} onChange={handleTextChange} placeholder="sk-..." className="mt-1 h-8 text-xs bg-white dark:bg-gray-700" />
<p className="text-xs text-red-500 mt-1"></p>
</div>
<div className="nodrag">
<Label htmlFor={`temperature-${id}`} className="text-xs font-semibold">: {data.temperature?.toFixed(1) ?? 'N/A'}</Label>
<Slider id={`temperature-${id}`} min={0} max={1} step={0.1} value={[data.temperature ?? 0.7]} onValueChange={handleSliderChange} className="mt-2" />
<div className="flex justify-between text-xs text-gray-500 dark:text-gray-400 mt-1"><span></span><span></span></div>
</div>
</div> </div>
<div className="relative p-2 border-t border-purple-200 dark:border-gray-600"> <div className="relative p-2 border-t border-purple-200 dark:border-gray-600">
<Label className="text-xs font-semibold text-gray-500 dark:text-gray-400 block text-right pr-6"></Label> <Label className="text-xs font-semibold text-gray-500 dark:text-gray-400 block text-right pr-6">
<Handle type="source" position={Position.Right} id="output-message" isConnectable={isConnectable} className="w-3 h-3 !bg-green-500 top-1/2" />
</div> </Label>
<Handle
type="source"
position={Position.Right}
id="output-message"
isConnectable={isConnectable}
className="!w-3 !h-3 !bg-green-500 top-1/2 z-50"
/>
</div>
</div> </div>
); );
}; };
LLMNodeComponent.displayName = 'LLMNode'; LLMNodeComponent.displayName = "LLMNode";
export const LLMNode = memo(LLMNodeComponent); export const LLMNode = memo(LLMNodeComponent);

View File

@ -2,7 +2,7 @@
// Description: 侧边栏组件,用于拖放节点 // Description: 侧边栏组件,用于拖放节点
import React, { DragEvent } from 'react'; import React, { DragEvent } from 'react';
import { MessageSquareText, BrainCircuit, Database, LogOut, Play, CheckCircle } from 'lucide-react'; import { MessageSquareText, BrainCircuit, Database, LogOut, Play, CheckCircle, MessageCircleQuestion, MessageCircleReply } from 'lucide-react';
// 定义可拖拽的节点类型 (从 page.tsx 移动过来) // 定义可拖拽的节点类型 (从 page.tsx 移动过来)
const nodeTypesForPalette = [ const nodeTypesForPalette = [
@ -11,6 +11,8 @@ const nodeTypesForPalette = [
{ type: 'llmNode', label: 'LLM 调用', icon: BrainCircuit, defaultData: { model: 'gpt-3.5-turbo', temperature: 0.7, systemPrompt: '你是一个乐于助人的 AI 助手。' } }, { type: 'llmNode', label: 'LLM 调用', icon: BrainCircuit, defaultData: { model: 'gpt-3.5-turbo', temperature: 0.7, systemPrompt: '你是一个乐于助人的 AI 助手。' } },
{ type: 'ragNode', label: 'RAG 查询', icon: Database, defaultData: { query: '...', knowledgeBase: 'default' } }, { type: 'ragNode', label: 'RAG 查询', icon: Database, defaultData: { query: '...', knowledgeBase: 'default' } },
{ type: 'outputNode', label: '结束流程', icon: CheckCircle, defaultData: { label: '结束' } }, { type: 'outputNode', label: '结束流程', icon: CheckCircle, defaultData: { label: '结束' } },
{ type: 'chatInputNode', label: 'Chat Input', icon: MessageCircleQuestion, defaultData: { text: '' } }, // 添加 ChatInput
{ type: 'chatOutputNode', label: 'Chat Output', icon: MessageCircleReply, defaultData: { displayText: '' } }, // 添加 ChatOutput
]; ];

View File

@ -21,6 +21,20 @@ export interface OutputNodeData {
label: string; label: string;
} }
// --- 新增 Chat 节点类型 ---
export interface ChatInputNodeData {
text: string; // 存储输入的文本
}
export interface ChatOutputNodeData {
// ChatOutput 通常接收输入并显示,或者只是一个结束点
// 如果需要配置显示方式,可以在这里添加字段
// 如果它也像截图那样有个输入框,可能是用于显示模板或最终结果
displayText?: string; // 用于显示或配置的文本
inputConnected?: boolean; // 标记输入是否连接
}
// --- 结束新增 ---
// 可以将 NodeProps 包装一下,方便使用 // 可以将 NodeProps 包装一下,方便使用
export type CustomNodeProps<T> = NodeProps<T>; export type CustomNodeProps<T> = NodeProps<T>;

View File

@ -24,6 +24,7 @@ import ReactFlow, {
useReactFlow, useReactFlow,
XYPosition, XYPosition,
BackgroundVariant, BackgroundVariant,
Panel,
} from 'reactflow'; } from 'reactflow';
// 引入自定义节点和侧边栏组件 // 引入自定义节点和侧边栏组件
@ -31,28 +32,21 @@ import { StartNode } from './components/StartNode';
import { OutputNode } from './components/OutputNode'; import { OutputNode } from './components/OutputNode';
import { LLMNode } from './components/LLMNode'; import { LLMNode } from './components/LLMNode';
import { WorkflowSidebar } from './components/WorkflowSidebar'; import { WorkflowSidebar } from './components/WorkflowSidebar';
import { ChatInputNode } from './components/ChatInputNode'; // 引入新节点
import { ChatOutputNode } from './components/ChatOutputNode'; // 引入新节点
// 引入类型 (如果需要) // 引入类型 (如果需要)
// import type { LLMNodeData } from './components/types'; // import type { LLMNodeData } from './components/types';
// 引入 API 函数
import { runWorkflow } from '@/lib/api'; // Import runWorkflow
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { toast } from 'sonner';
import { Button } from '@/components/ui/button';
import { LoaderIcon, PlayIcon } from 'lucide-react';
// --- 初始节点和边数据 --- // --- 初始节点和边数据 ---
const initialNodes: Node[] = [ const initialNodes: Node[] = [
{ id: 'start-initial', type: 'startNode', data: { label: '开始' }, position: { x: 150, y: 50 } }, ];
{
id: 'llm-initial',
type: 'llmNode',
data: {
model: 'gpt-3.5-turbo',
temperature: 0.7,
systemPrompt: '你是一个乐于助人的 AI 助手。',
inputConnected: false,
apiKey: '',
},
position: { x: 350, y: 150 }
},
{ id: 'end-initial', type: 'outputNode', data: { label: '结束' }, position: { x: 650, y: 350 } },
];
const initialEdges: Edge[] = []; const initialEdges: Edge[] = [];
// --- 工作流编辑器组件 --- // --- 工作流编辑器组件 ---
@ -60,13 +54,16 @@ function WorkflowEditor() {
const reactFlowWrapper = useRef<HTMLDivElement>(null); // Ref 指向 React Flow 容器 const reactFlowWrapper = useRef<HTMLDivElement>(null); // Ref 指向 React Flow 容器
const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes); const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges); const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);
const { project } = useReactFlow(); // 使用 hook 获取 project 方法 const { screenToFlowPosition } = useReactFlow();
const [isRunning, setIsRunning] = useState(false); // State for run button loading
// --- 注册自定义节点类型 --- // --- 注册自定义节点类型 ---
const nodeTypes = useMemo(() => ({ const nodeTypes = useMemo(() => ({
startNode: StartNode, startNode: StartNode,
outputNode: OutputNode, outputNode: OutputNode,
llmNode: LLMNode, // LLMNode 现在需要一种方式来调用 updateNodeData llmNode: LLMNode, // LLMNode 现在需要一种方式来调用 updateNodeData
chatInputNode: ChatInputNode, // 注册 ChatInputNode
chatOutputNode: ChatOutputNode, // 注册 ChatOutputNode
// inputNode: CustomInputNode, // inputNode: CustomInputNode,
// ragNode: RAGNode, // ragNode: RAGNode,
}), []); // 移除 updateNodeData 依赖,因为它不应该直接传递 }), []); // 移除 updateNodeData 依赖,因为它不应该直接传递
@ -84,6 +81,10 @@ function WorkflowEditor() {
}) })
); );
} }
// 更新 ChatOutputNode 的连接状态 (如果需要)
if (connection.targetHandle === 'message-input' && connection.target) {
setNodes((nds) => nds.map((node) => node.id === connection.target ? { ...node, data: { ...node.data, inputConnected: true } } : node));
}
}, },
[setEdges, setNodes] // 保持 updateNodeData 依赖 [setEdges, setNodes] // 保持 updateNodeData 依赖
); );
@ -103,6 +104,10 @@ function WorkflowEditor() {
}) })
); );
} }
// 更新 ChatOutputNode 的断开状态 (如果需要)
if (edgeToRemove?.targetHandle === 'message-input' && edgeToRemove.target) {
setNodes((nds) => nds.map((node) => node.id === edgeToRemove.target ? { ...node, data: { ...node.data, inputConnected: false } } : node));
}
} }
}); });
onEdgesChange(changes); onEdgesChange(changes);
@ -148,10 +153,9 @@ function WorkflowEditor() {
// 获取鼠标相对于 React Flow 画布的位置 // 获取鼠标相对于 React Flow 画布的位置
// 需要计算鼠标位置相对于 reactFlowWrapper 的偏移量 // 需要计算鼠标位置相对于 reactFlowWrapper 的偏移量
const reactFlowBounds = reactFlowWrapper.current.getBoundingClientRect(); const position = screenToFlowPosition({
const position = project({ x: event.clientX,
x: event.clientX - reactFlowBounds.left, y: event.clientY,
y: event.clientY - reactFlowBounds.top,
}); });
// 创建新节点 // 创建新节点
@ -169,9 +173,45 @@ function WorkflowEditor() {
// 将新节点添加到状态中 // 将新节点添加到状态中
setNodes((nds) => nds.concat(newNode)); setNodes((nds) => nds.concat(newNode));
}, },
[project, setNodes] // 依赖 project 方法和 setNodes [screenToFlowPosition, setNodes] // 依赖 project 方法和 setNodes
); );
// --- 处理工作流运行 ---
const handleRunWorkflow = useCallback(async () => {
setIsRunning(true);
toast.info("正在执行工作流...");
console.log("Running workflow with nodes:", nodes);
console.log("Running workflow with edges:", edges);
try {
const result = await runWorkflow(nodes, edges); // Call the API function
if (result.success && result.output !== undefined && result.output_node_id) {
toast.success(result.message || "工作流执行成功!");
console.log("Workflow output:", result.output);
// 更新 ChatOutputNode 的数据以显示结果
setNodes((nds) =>
nds.map((node) => {
if (node.id === result.output_node_id && node.type === 'chatOutputNode') {
return { ...node, data: { ...node.data, displayText: result.output } };
}
return node;
})
);
} else {
toast.error(result.message || "工作流执行失败。");
console.error("Workflow execution failed:", result.message);
}
} catch (error) {
// This catch block might not be necessary if runWorkflow handles errors
toast.error("执行工作流时发生网络错误。");
console.error("API call error:", error);
} finally {
setIsRunning(false);
}
}, [nodes, edges, setNodes]); // Depend on nodes and edges
return ( return (
<div className="flex flex-col h-full bg-white dark:bg-gray-800 rounded-lg shadow-md overflow-hidden"> <div className="flex flex-col h-full bg-white dark:bg-gray-800 rounded-lg shadow-md overflow-hidden">
<h1 className="text-xl font-semibold p-4 border-b dark:border-gray-700 text-gray-800 dark:text-gray-200"> <h1 className="text-xl font-semibold p-4 border-b dark:border-gray-700 text-gray-800 dark:text-gray-200">
@ -195,6 +235,22 @@ function WorkflowEditor() {
fitView fitView
className="bg-gray-50 dark:bg-gray-900" className="bg-gray-50 dark:bg-gray-900"
> >
{/* 使用 Panel 添加运行按钮到右上角 */}
<Panel position="top-right" className="p-2">
<Button
onClick={handleRunWorkflow}
disabled={isRunning}
size="sm"
className="bg-green-600 hover:bg-green-700 text-white"
>
{isRunning ? (
<LoaderIcon className="mr-2 h-4 w-4 animate-spin" />
) : (
<PlayIcon className="mr-2 h-4 w-4" />
)}
</Button>
</Panel>
<Controls /> <Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable /> <MiniMap nodeStrokeWidth={3} zoomable pannable />
<Background gap={16} color="#ccc" variant={BackgroundVariant.Dots} /> <Background gap={16} color="#ccc" variant={BackgroundVariant.Dots} />

View File

@ -1,33 +1,51 @@
// File: frontend/lib/api.ts (更新) // File: frontend/lib/api.ts (Update)
// Description: 添加调用助手和会话管理 API 的函数 // Description: 添加运行工作流的 API 函数
import { Assistant, AssistantCreateData, AssistantUpdateData } from "@/types/assistant";
import axios from "axios"; import axios from "axios";
import type { Node, Edge } from "reactflow"; // Import React Flow types
import type {
Assistant,
Session,
Message,
AssistantCreateData,
AssistantUpdateData,
ChatApiResponse,
} from "./types"; // Assuming types are defined
// --- Types --- // --- Types ---
export interface Session { // Workflow Run types (match backend pydantic models)
id: string; interface WorkflowNodeData {
title: string; label?: string | null;
assistant_id: string; text?: string | null;
created_at: string; // ISO date string displayText?: string | null;
updated_at?: string | null; // Add updated_at model?: string | null;
temperature?: number | null;
systemPrompt?: string | null;
// Add other node data fields as needed
[key: string]: any; // Allow extra fields
} }
interface WorkflowNode {
// Message type from backend id: string;
export interface Message { type: string;
id: string; position: { x: number; y: number };
session_id: string; data: WorkflowNodeData;
sender: 'user' | 'ai'; // Or extend with 'system' if needed
text: string;
order: number;
created_at: string; // ISO date string
} }
interface WorkflowEdge {
// 聊天响应类型 id: string;
export interface ChatApiResponse { source: string;
reply: string; target: string;
session_id?: string | null; // 后端返回的新 session id sourceHandle?: string | null;
session_title?: string | null; // 后端返回的新 session title targetHandle?: string | null;
}
interface WorkflowRunPayload {
nodes: WorkflowNode[];
edges: WorkflowEdge[];
}
export interface WorkflowRunResult {
success: boolean;
message?: string | null;
output?: string | null;
output_node_id?: string | null;
} }
// --- API Client Setup --- // --- API Client Setup ---
@ -159,17 +177,61 @@ export const deleteSession = async (sessionId: string): Promise<void> => {
// --- Message API (New) --- // --- Message API (New) ---
/** 获取指定会话的消息列表 */ /** 获取指定会话的消息列表 */
export const getMessagesBySession = async (sessionId: string, limit: number = 100, skip: number = 0): Promise<Message[]> => { export const getMessagesBySession = async (
sessionId: string,
limit: number = 100,
skip: number = 0
): Promise<Message[]> => {
try {
const response = await apiClient.get<Message[]>(
`/messages/session/${sessionId}`,
{
params: { limit, skip },
}
);
return response.data;
} catch (error) {
// Handle 404 specifically if needed (session exists but no messages)
if (axios.isAxiosError(error) && error.response?.status === 404) {
return []; // Return empty list if session not found or no messages
}
throw new Error(handleApiError(error, "getMessagesBySession"));
}
};
// --- Workflow API (New) ---
/**
*
* @param nodes - React Flow
* @param edges - React Flow
* @returns
*/
export const runWorkflow = async (nodes: Node[], edges: Edge[]): Promise<WorkflowRunResult> => {
// Map React Flow nodes/edges to the structure expected by the backend API
const payload: WorkflowRunPayload = {
nodes: nodes.map(n => ({
id: n.id,
type: n.type || 'default', // Ensure type is present
position: n.position,
data: n.data as WorkflowNodeData, // Assume data matches for now
})),
edges: edges.map(e => ({
id: e.id,
source: e.source,
target: e.target,
sourceHandle: e.sourceHandle,
targetHandle: e.targetHandle,
})),
};
try { try {
const response = await apiClient.get<Message[]>(`/messages/session/${sessionId}`, { const response = await apiClient.post<WorkflowRunResult>('/workflow/run', payload);
params: { limit, skip }
});
return response.data; return response.data;
} catch (error) { } catch (error) {
// Handle 404 specifically if needed (session exists but no messages) // Return a failed result structure on API error
if (axios.isAxiosError(error) && error.response?.status === 404) { return {
return []; // Return empty list if session not found or no messages success: false,
} message: handleApiError(error, 'runWorkflow'),
throw new Error(handleApiError(error, 'getMessagesBySession')); };
} }
}; };

49
frontend/lib/types.ts Normal file
View File

@ -0,0 +1,49 @@
// --- Types (从后端模型同步或手动定义) ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 创建助手时发送的数据类型
export interface AssistantCreateData {
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;
// --- Types ---
export interface Session {
id: string;
title: string;
assistant_id: string;
created_at: string; // ISO date string
updated_at?: string | null; // Add updated_at
}
// Message type from backend
export interface Message {
id: string;
session_id: string;
sender: "user" | "ai"; // Or extend with 'system' if needed
text: string;
order: number;
created_at: string; // ISO date string
}
// 聊天响应类型
export interface ChatApiResponse {
reply: string;
session_id?: string | null; // 后端返回的新 session id
session_title?: string | null; // 后端返回的新 session title
}