# File: backend/app/flow_components/base.py (New)
# Description: Base class, field schemas, and registry for workflow node components.
import copy
import logging
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Type, ClassVar

from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession  # Import if components need DB access
|
|
|
|
# --- Input/Output Field Definitions ---
# These models describe what a component accepts and produces, both as
# connection handles and as configurable UI fields.


class InputField(BaseModel):
    """Schema for one component input, either a UI field or a connection handle."""

    # Internal key; matches the handle id or the node-data key.
    name: str
    # Human-readable label shown in the UI.
    display_name: str
    # Optional tooltip / description text.
    info: Optional[str] = Field(default=None)
    # Data type tag, e.g. 'str', 'int', 'float', 'bool', 'dict', 'code', 'prompt', 'llm', 'message'.
    field_type: str
    # Whether the input must resolve to a non-None value.
    required: bool = Field(default=True)
    # Default value for UI fields, used when nothing else supplies one.
    value: Any = Field(default=None)
    # True when the value arrives through a handle connection rather than the UI.
    is_handle: bool = Field(default=False)
    # Extra UI metadata: the choices offered by a dropdown...
    options: Optional[List[str]] = Field(default=None)
    # ...and slider bounds, e.g. {'min': 0, 'max': 1, 'step': 0.1}.
    range_spec: Optional[Dict[str, float]] = Field(default=None)
|
|
|
|
class OutputField(BaseModel):
    """Schema for one component output handle."""

    # Internal key; matches the output handle id.
    name: str
    # Human-readable label shown in the UI.
    display_name: str
    # Data type tag of the value produced on this handle.
    field_type: str
    # Optional tooltip / description text.
    info: Optional[str] = Field(default=None)
|
|
|
|
# --- Base Component Class ---
class BaseComponent(ABC, BaseModel):
    """Abstract base for every workflow node component.

    Subclasses describe themselves through class-level metadata
    (``name``, ``display_name``, ``description``, ``icon``) and declare their
    input/output schemas in the ``inputs`` / ``outputs`` class variables.
    Per-node configuration sent by the frontend is held in the ``node_data``
    instance field. Subclasses must implement :meth:`run`.
    """

    # Class-level metadata (override in subclasses).
    display_name: ClassVar[str] = "Base Component"
    description: ClassVar[str] = "A base component for workflow nodes."
    icon: ClassVar[Optional[str]] = None  # Icon name (e.g., from Lucide)
    # Unique internal type identifier; must match the frontend node type.
    # Only annotated here (no value), so subclasses are expected to set it.
    name: ClassVar[str]

    # Node-specific data (UI field values) from the frontend.
    # default_factory makes it explicit that each instance gets its own dict.
    node_data: Dict[str, Any] = Field(default_factory=dict)

    # Class-level input/output schemas (override in subclasses). These are
    # shared by every instance of the class.
    inputs: ClassVar[List[InputField]] = []
    outputs: ClassVar[List[OutputField]] = []

    class Config:
        arbitrary_types_allowed = True  # Allow complex types like AsyncSession if needed

    @abstractmethod
    async def run(
        self,
        inputs: Dict[str, Any],
        db: Optional[AsyncSession] = None,
    ) -> Dict[str, Any]:
        """Execute the component's logic.

        Args:
            inputs: Resolved inputs from UI data and connected nodes, keyed
                by input field name.
            db: Optional database session for components that need DB access.

        Returns:
            A dictionary of output values keyed by output name.
        """

    def validate_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Resolve every declared input and check that required ones are present.

        Each field in ``self.inputs`` is resolved with this precedence:

        1. the value in ``inputs`` (from connected nodes), if the key is present;
        2. otherwise the value in ``self.node_data`` (UI configuration), if present;
        3. otherwise a deep copy of the field's declared default ``field.value``.

        A key that is present with the value ``None`` still wins at its level.

        Args:
            inputs: Values resolved from upstream connections, keyed by field name.

        Returns:
            A new dict mapping each declared input name to its resolved value.
            Keys in ``inputs`` that are not declared fields are not included.

        Raises:
            ValueError: If any required field resolves to ``None``. The
                message lists the display names of all missing fields.
        """
        resolved_inputs: Dict[str, Any] = {}
        missing: List[str] = []

        for field in self.inputs:
            if field.name in inputs:
                value = inputs[field.name]
            elif field.name in self.node_data:
                value = self.node_data[field.name]
            else:
                # The default lives on the class-level schema shared by all
                # instances. Copy it so a component that mutates its input
                # (e.g. a dict/list default) cannot corrupt the default for
                # later runs.
                value = copy.deepcopy(field.value)

            if field.required and value is None:
                missing.append(field.display_name)
            # TODO: Add type validation based on field.field_type
            resolved_inputs[field.name] = value

        if missing:
            raise ValueError(f"节点 '{self.display_name}' 缺少必需的输入: {', '.join(missing)}")
        return resolved_inputs
|
|
|
|
|
|
# --- Component Registry ---
# Lookup table from a node-type string (a component's ``name``) to its
# component class; entries are added by the ``register_component`` decorator.
component_registry: Dict[str, Type[BaseComponent]] = dict()
|
|
|
|
def register_component(cls: Type[BaseComponent]) -> Type[BaseComponent]:
    """Class decorator that adds a component class to ``component_registry``.

    The class is stored under its ``name`` class attribute. Registering a
    class under a name that is already present replaces the earlier entry
    and logs a warning.

    Args:
        cls: The component class to register.

    Returns:
        ``cls`` unchanged, so the function can be used as a decorator.

    Raises:
        ValueError: If ``cls`` has no ``name`` attribute or it is empty.
    """
    logger = logging.getLogger(__name__)
    # A missing attribute and an empty/falsy name are rejected the same way.
    component_name = getattr(cls, "name", None)
    if not component_name:
        raise ValueError(f"Component class {cls.__name__} must have a 'name' attribute.")
    if component_name in component_registry:
        logger.warning("警告: 组件 '%s' 被重复注册。", component_name)
    component_registry[component_name] = cls
    logger.info("已注册组件: %s", component_name)
    return cls