2025-05-02 17:31:33 +08:00

86 lines
3.7 KiB
Python

# File: backend/app/flow_components/base.py (New)
# Description: Base class for all workflow node components
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Type, ClassVar

from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession  # Import if components need DB access
# --- Input/Output Field Definitions ---
# These classes help define the expected inputs and outputs for handles/UI fields
class InputField(BaseModel):
    # Declarative description of a single component input. An input is either
    # configured in the UI or, when ``is_handle`` is True, supplied through a
    # handle connection.
    # NOTE: kept as comments rather than a class docstring because pydantic
    # copies a model's docstring into its generated JSON schema description.
    name: str  # Internal name/key; matches the handleId or the data key
    display_name: str  # User-friendly label shown in the UI
    info: Optional[str] = None  # Optional tooltip/description text
    field_type: str  # Type tag, e.g. 'str', 'int', 'float', 'bool', 'dict', 'code', 'prompt', 'llm', 'message'
    required: bool = True  # Inputs are required unless explicitly marked otherwise
    value: Any = None  # Default value for UI fields
    is_handle: bool = False  # True if this input comes from a Handle connection
    # Optional extra UI metadata (e.g. options for dropdowns, range for sliders).
    options: Optional[List[str]] = None  # Allowed choices, if any
    range_spec: Optional[Dict[str, float]] = None  # e.g. {'min': 0, 'max': 1, 'step': 0.1}
class OutputField(BaseModel):
    # Declarative description of a single component output handle.
    # NOTE: kept as comments rather than a class docstring because pydantic
    # copies a model's docstring into its generated JSON schema description.
    name: str  # Internal name/key; matches the handleId
    display_name: str  # User-friendly label
    field_type: str  # Data type of the output handle
    info: Optional[str] = None  # Optional tooltip/description text
# --- Base Component Class ---
class BaseComponent(ABC, BaseModel):
    # Abstract base for workflow node components. Subclasses describe
    # themselves through the class-level metadata below and implement ``run``.

    # --- Class-level metadata (subclasses may override) ---
    display_name: ClassVar[str] = "Base Component"
    description: ClassVar[str] = "A base component for workflow nodes."
    icon: ClassVar[Optional[str]] = None  # Icon name (e.g., from Lucide)
    # Unique internal name/type identifier (matches frontend node type).
    # Declared without a value here; concrete components supply it.
    name: ClassVar[str]

    # --- Per-instance state ---
    # Node-specific data coming from the frontend.
    node_data: Dict[str, Any] = {}

    # --- Declared inputs and outputs ---
    inputs: ClassVar[List[InputField]] = []
    outputs: ClassVar[List[OutputField]] = []

    class Config:
        # Permit non-pydantic types such as AsyncSession if a component needs them.
        arbitrary_types_allowed = True

    @abstractmethod
    async def run(
        self,
        inputs: Dict[str, Any],
        db: Optional[AsyncSession] = None,
    ) -> Dict[str, Any]:
        """Execute the component's logic.

        Args:
            inputs: Resolved inputs from UI data and connected nodes.
            db: Optional database session for components that need one.

        Returns:
            Output values keyed by output name.
        """

    def validate_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve a value for every declared input and check required ones.

        For each declared input the value is taken from ``inputs`` when the
        key is present there, otherwise from ``node_data`` when the key is
        present there, otherwise from the field's default ``value``.

        Args:
            inputs: Values keyed by input name.

        Returns:
            A dict with one entry per declared input name.

        Raises:
            ValueError: If any required input resolves to ``None``.
        """
        resolved: Dict[str, Any] = {}
        absent: List[str] = []

        for spec in self.inputs:
            key = spec.name
            # Precedence: explicit inputs, then node_data, then field default.
            if key in inputs:
                candidate = inputs[key]
            elif key in self.node_data:
                candidate = self.node_data[key]
            else:
                candidate = spec.value

            if candidate is None and spec.required:
                absent.append(spec.display_name)
            # TODO: Add type validation based on spec.field_type
            resolved[key] = candidate

        if not absent:
            return resolved
        raise ValueError(f"节点 '{self.display_name}' 缺少必需的输入: {', '.join(absent)}")
# --- Component Registry ---
# Simple dictionary to map node type strings to component classes
# Registration diagnostics go through ``logging`` instead of ``print`` so the
# host application controls where (and whether) they are emitted.
_logger = logging.getLogger(__name__)

# Maps a component's ``name`` (node type string) to its component class.
# ``BaseComponent`` is referenced as a forward reference so this section does
# not need the class object at annotation-evaluation time.
component_registry: Dict[str, Type["BaseComponent"]] = {}


def register_component(cls: Type["BaseComponent"]) -> Type["BaseComponent"]:
    """Class decorator that records ``cls`` in ``component_registry``.

    The class is stored under its ``name`` attribute. Registering a second
    class under an existing name logs a warning and replaces the earlier
    entry.

    Args:
        cls: The component class to register.

    Returns:
        ``cls`` unchanged, so the function can be used as a decorator.

    Raises:
        ValueError: If ``cls`` has no ``name`` attribute or it is empty.
    """
    component_name = getattr(cls, "name", None)
    if not component_name:
        raise ValueError(f"Component class {cls.__name__} must have a 'name' attribute.")

    if component_name in component_registry:
        _logger.warning("警告: 组件 '%s' 被重复注册。", component_name)

    component_registry[component_name] = cls
    _logger.info("已注册组件: %s", component_name)
    return cls