diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index f3602901aa5..a46a7d9eb66 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -15,7 +15,7 @@ from .function_node import * from .question_node import * from .reranker_node import * - +from .loop_node import * from .document_extract_node import * from .image_understand_step_node import * from .image_generate_step_node import * @@ -31,7 +31,7 @@ BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, BaseImageUnderstandNode, BaseFormNode, BaseSpeechToTextNode, BaseTextToSpeechNode, - BaseImageGenerateNode, BaseVariableAssignNode] + BaseImageGenerateNode, BaseVariableAssignNode, BaseLoopNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/loop_node/__init__.py b/apps/application/flow/step_node/loop_node/__init__.py new file mode 100644 index 00000000000..a5f59372be7 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2025/3/11 18:24 + @desc: +""" +from .impl import * diff --git a/apps/application/flow/step_node/loop_node/i_loop_node.py b/apps/application/flow/step_node/loop_node/i_loop_node.py new file mode 100644 index 00000000000..dfa0b3301e4 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/i_loop_node.py @@ -0,0 +1,57 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_loop_node.py + @date:2025/3/11 18:19 + @desc: +""" +from typing import Type + +from application.flow.i_step_node import INode, NodeResult +from rest_framework import serializers + +from common.exception.app_exception import AppApiException +from common.util.field_message import ErrMessage +from django.utils.translation import gettext_lazy as _ + + +class ILoopNodeSerializer(serializers.Serializer): + loop_type = serializers.CharField(required=True, error_messages=ErrMessage.char(_("loop_type"))) + array = serializers.ListField(required=False, allow_null=True, + error_messages=ErrMessage.char(_("array"))) + number = serializers.IntegerField(required=False, allow_null=True, + error_messages=ErrMessage.char(_("number"))) + loop_body = serializers.DictField(required=True, error_messages=ErrMessage.char("循环体")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + loop_type = self.data.get('loop_type') + if loop_type == 'ARRAY': + array = self.data.get('array') + if array is None or len(array) == 0: + message = _('{field}, this field is required.', field='array') + raise AppApiException(500, message) + elif loop_type == 'NUMBER': + number = self.data.get('number') + if number is None: + message = _('{field}, this field is required.', field='number') + raise AppApiException(500, message) + + +class ILoopNode(INode): + type = 'loop-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return ILoopNodeSerializer + + def _run(self): + array = self.node_params_serializer.data.get('array') + if self.node_params_serializer.data.get('loop_type') == 'ARRAY': + array = self.workflow_manage.get_reference_field( + array[0], + array[1:]) + return self.execute(**{**self.node_params_serializer.data, "array": array}, **self.flow_params_serializer.data) + + def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/loop_node/impl/__init__.py b/apps/application/flow/step_node/loop_node/impl/__init__.py new file mode 100644 index 00000000000..3cd082322a1 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2025/3/11 18:24 + @desc: +""" +from .base_loop_node import BaseLoopNode diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py new file mode 100644 index 00000000000..6730e0daae5 --- /dev/null +++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py @@ -0,0 +1,195 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_loop_node.py + @date:2025/3/11 18:24 + @desc: +""" +import time +from typing import Dict + +from application.flow.i_step_node import NodeResult, WorkFlowPostHandler, INode +from application.flow.step_node.loop_node.i_loop_node import ILoopNode +from application.flow.tools import Reasoning +from common.handle.impl.response.loop_to_response import LoopToResponse + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str, + reasoning_content: str): + node.context['answer'] = answer + node.context['run_time'] = time.time() - node.context['start_time'] + node.context['reasoning_content'] = reasoning_content + if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): + node.answer_text = answer + + +def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 (流式) + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点 + @param workflow: 工作流管理器 + """ + + response = node_variable.get('result') + workflow_manage = node_variable.get('workflow_manage') + answer = '' + reasoning_content = '' + for chunk in response: + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield {'content': content_chunk, + 'reasoning_content': reasoning_content_chunk} + runtime_details = workflow_manage.get_runtime_details() + _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content) + + +def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + """ + 写入上下文数据 + @param node_variable: 节点数据 + @param workflow_variable: 全局数据 + @param node: 节点实例对象 + @param workflow: 工作流管理器 + """ + response = node_variable.get('result') + model_setting = node.context.get('model_setting', + {'reasoning_content_enable': False, 'reasoning_content_end': '', + 'reasoning_content_start': ''}) + reasoning = Reasoning(model_setting.get('reasoning_content_start'), model_setting.get('reasoning_content_end')) + reasoning_result = reasoning.get_reasoning_content(response) + reasoning_result_end = reasoning.get_end_reasoning_content() + content = reasoning_result.get('content') + reasoning_result_end.get('content') + if 'reasoning_content' in response.response_metadata: + reasoning_content = response.response_metadata.get('reasoning_content', '') + else: + reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content') + _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content) + + +def loop_number(number: int, workflow_manage_new_instance, node: INode): + loop_global_data = {} + for index in range(number): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index}, loop_global_data) + response = instance.stream() + answer = '' + reasoning_content = '' + for chunk in response: + content_chunk = chunk.get('content', '') + reasoning_content_chunk = chunk.get('reasoning_content', '') + reasoning_content += reasoning_content_chunk + answer += content_chunk + yield chunk + loop_global_data = instance.context + + +def loop_array(array, workflow_manage_new_instance, node: INode): + loop_global_data = {} + loop_execute_details = [] + for item, index in zip(array, range(len(array))): + """ + 指定次数循环 + @return: + """ + instance = workflow_manage_new_instance({'index': index, 'item': item}, loop_global_data) + response = instance.stream() + for chunk in response: + yield chunk + loop_global_data = instance.context + runtime_details = instance.get_runtime_details() + loop_execute_details.append(runtime_details) + node.context['loop_execute_details'] = loop_execute_details + + +def get_write_context(loop_type, array, number, loop_body, stream): + def inner_write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): + if loop_type == 'ARRAY': + return loop_array(array, node_variable['workflow_manage_new_instance'], node) + return loop_number(number, node_variable['workflow_manage_new_instance'], node) + + return inner_write_context + + +class LoopWorkFlowPostHandler(WorkFlowPostHandler): + def handler(self, chat_id, + chat_record_id, + answer, + workflow): + pass + + +class BaseLoopNode(ILoopNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = str(details.get('result')) + + def execute(self, loop_type, array, number, loop_body, stream, **kwargs) -> NodeResult: + from application.flow.workflow_manage import WorkflowManage, Flow + def workflow_manage_new_instance(start_data, global_data): + workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params, + LoopWorkFlowPostHandler( + self.workflow_manage.work_flow_post_handler.chat_info, + self.workflow_manage.work_flow_post_handler.client_id, + self.workflow_manage.work_flow_post_handler.client_type) + , base_to_response=LoopToResponse(), + start_data=start_data, + form_data=global_data) + + return workflow_manage + + return NodeResult({'workflow_manage_new_instance': workflow_manage_new_instance}, {}, + _write_context=get_write_context(loop_type, array, number, loop_body, stream)) + + def loop_number(self, number: int, loop_body, stream): + for index in range(number): + """ + 指定次数循环 + @return: + """ + from application.flow.workflow_manage import WorkflowManage, Flow + workflow_manage = WorkflowManage(Flow.new_instance(loop_body), self.workflow_manage.params, + LoopWorkFlowPostHandler( + self.workflow_manage.work_flow_post_handler.chat_info + , + self.workflow_manage.work_flow_post_handler.client_id, + self.workflow_manage.work_flow_post_handler.client_type) + , base_to_response=LoopToResponse(), + start_data={'index': index}) + result = workflow_manage.stream() + return NodeResult({"result": result, "workflow_manage": workflow_manage}, {}, + _write_context=write_context_stream) + pass + + def loop_array(self, array, loop_body, stream): + """ + 循环数组 + @return: + """ + pass + + def loop_loop(self, loop_body, stream): + """ + 无线循环 + @return: + """ + pass + + def get_details(self, index: int, **kwargs): + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": self.context.get('result'), + "params": self.context.get('params'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index a74e5cc8e96..5107d4ce2c8 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -88,7 +88,7 @@ def execute(self, dataset_id_list, dataset_setting, question, 'is_hit_handling_method_list': [row for row in result if row.get('is_hit_handling_method')], 'data': '\n'.join( [f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in - paragraph_list])[0:dataset_setting.get('max_paragraph_char_number', 5000)], + result])[0:dataset_setting.get('max_paragraph_char_number', 5000)], 'directly_return': '\n'.join( [paragraph.get('content') for paragraph in result if diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index bf5203274eb..d4a54a25b13 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -33,7 +33,10 @@ def get_global_variable(node): class BaseStartStepNode(IStarNode): def save_context(self, details, workflow_manage): base_node = self.workflow_manage.get_base_node() - default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + default_global_variable = {} + if base_node is not None: + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} self.context['question'] = details.get('question') self.context['run_time'] = details.get('run_time') @@ -50,7 +53,9 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: def execute(self, question, **kwargs) -> NodeResult: base_node = self.workflow_manage.get_base_node() - default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + default_global_variable = {} + if base_node is not None: + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) workflow_variable = {**default_global_variable, **get_global_variable(self)} """ 开始节点 初始化全局变量 @@ -59,7 +64,9 @@ def execute(self, question, **kwargs) -> NodeResult: 'question': question, 'image': self.workflow_manage.image_list, 'document': self.workflow_manage.document_list, - 'audio': self.workflow_manage.audio_list + 'audio': self.workflow_manage.audio_list, + **self.workflow_manage.start_data + } return NodeResult(node_variable, workflow_variable) diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index be91f69be9e..ef2272a5e73 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -239,9 +239,11 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl document_list=None, audio_list=None, start_node_id=None, - start_node_data=None, chat_record=None, child_node=None): + start_node_data=None, chat_record=None, child_node=None, start_data=None): if form_data is None: form_data = {} + if start_data is None: + start_data = {} if image_list is None: image_list = [] if document_list is None: @@ -272,6 +274,7 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl self.field_list = [] self.global_field_list = [] self.init_fields() + self.start_data = start_data if start_node_id is not None: self.load_node(chat_record, start_node_id, start_node_data) else: @@ -338,6 +341,12 @@ def get_node_params(n): node.node_chunk.end() self.node_context.append(node) + def stream(self): + close_old_connections() + language = get_language() + self.run_chain_async(self.start_node, None, language) + return self.await_result() + def run(self): close_old_connections() language = get_language() @@ -801,6 +810,8 @@ def get_base_node(self): @return: """ base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] + if len(base_node_list) == 0: + return None return base_node_list[0] def get_node_cls_by_id(self, node_id, up_node_id_list=None, diff --git a/apps/common/handle/impl/response/loop_to_response.py b/apps/common/handle/impl/response/loop_to_response.py new file mode 100644 index 00000000000..7e6553ab757 --- /dev/null +++ b/apps/common/handle/impl/response/loop_to_response.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: LoopToResponse.py + @date:2025/3/12 17:21 + @desc: +""" +import json + +from common.handle.impl.response.system_to_response import SystemToResponse + + +class LoopToResponse(SystemToResponse): + + def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, + completion_tokens, + prompt_tokens, other_params: dict = None): + if other_params is None: + other_params = {} + return {'chat_id': str(chat_id), 'chat_record_id': str(chat_record_id), 'operate': True, + 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, + 'is_end': is_end, + 'usage': {'completion_tokens': completion_tokens, + 'prompt_tokens': prompt_tokens, + 'total_tokens': completion_tokens + prompt_tokens}, + **other_params} diff --git a/ui/src/enums/workflow.ts b/ui/src/enums/workflow.ts index bcd83a05929..b8f84c5df96 100644 --- a/ui/src/enums/workflow.ts +++ b/ui/src/enums/workflow.ts @@ -16,5 +16,7 @@ export enum WorkflowType { FormNode = 'form-node', TextToSpeechNode = 'text-to-speech-node', SpeechToTextNode = 'speech-to-text-node', - ImageGenerateNode = 'image-generate-node' + ImageGenerateNode = 'image-generate-node', + LoopNode = 'loop-node', + LoopBodyNode = 'loop-body-node' } diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index 0d065f9d8ec..aaac42f1739 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -7,7 +7,7 @@ >
-
+
{ } function clickNodes(item: any) { + console.log('clickNodes', item) const width = item.properties.width ? item.properties.width : 214 const nodeModel = props.nodeModel.graphModel.addNode({ type: item.type, diff --git a/ui/src/workflow/common/app-node.ts b/ui/src/workflow/common/app-node.ts index c1739aca6b0..4f054093b81 100644 --- a/ui/src/workflow/common/app-node.ts +++ b/ui/src/workflow/common/app-node.ts @@ -46,7 +46,6 @@ class AppNode extends HtmlResize.view { getNodesName(number + 1) } } - props.model.properties.config = nodeDict[props.model.type].properties.config if (props.model.properties.height) { props.model.height = props.model.properties.height } @@ -115,12 +114,11 @@ class AppNode extends HtmlResize.view { } else { isConnect = this.props.graphModel.edges.some((edge) => edge.sourceAnchorId == anchorData.id) } - return lh( 'foreignObject', { ...anchorData, - x: x - 10, + x: x - 14, y: y - 12, width: 30, height: 30 @@ -134,7 +132,7 @@ class AppNode extends HtmlResize.view { } }, dangerouslySetInnerHTML: { - __html: isConnect + __html: (type == 'children' ? true : isConnect) ? ` diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index 6554c9b7db4..5eb276f6d5f 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -35,6 +35,34 @@ export const startNode = { showNode: true } } +export const loopStartNode = { + id: WorkflowType.Start, + type: WorkflowType.Start, + x: 480, + y: 3340, + properties: { + height: 364, + stepName: t('views.applicationWorkflow.nodes.startNode.label'), + config: { + fields: [ + { + label: t('views.applicationWorkflow.nodes.startNode.index', '下标'), + value: 'index' + }, + { + label: t('views.applicationWorkflow.nodes.startNode.item', '循环元素'), + value: 'item' + } + ], + globalFields: [] + }, + fields: [{ label: t('views.applicationWorkflow.nodes.startNode.question'), value: 'question' }], + globalFields: [ + { label: t('views.applicationWorkflow.nodes.startNode.currentTime'), value: 'time' } + ], + showNode: true + } +} export const baseNode = { id: WorkflowType.Base, type: WorkflowType.Base, @@ -319,6 +347,65 @@ export const textToSpeechNode = { } } } + +export const loopNode = { + type: WorkflowType.LoopNode, + visible: false, + text: t('views.applicationWorkflow.nodes.loopNode.text', '循环节点'), + label: t('views.applicationWorkflow.nodes.loopNode.label', '循环节点'), + height: 252, + properties: { + stepName: t('views.applicationWorkflow.nodes.loopNode.label', '循环节点'), + workflow: { + edges: [], + nodes: [ + { + x: 480, + y: 3340, + id: 'start-node', + type: 'start-node', + properties: { + config: { + fields: [], + globalFields: [] + }, + fields: [], + height: 361.333, + showNode: true, + stepName: '开始', + globalFields: [] + } + } + ] + }, + config: { + fields: [ + { + label: t('loop.item', '循环参数'), + value: 'item' + }, + { + label: t('common.result'), + value: 'result' + } + ] + } + } +} + +export const loopBodyNode = { + type: WorkflowType.LoopBodyNode, + text: t('views.applicationWorkflow.nodes.loopBodyNode.text', '循环体'), + label: t('views.applicationWorkflow.nodes.loopBodyNode.label', '循环体'), + height: 600, + properties: { + width: 1800, + stepName: t('views.applicationWorkflow.nodes.loopBodyNode.label', '循环体'), + config: { + fields: [] + } + } +} export const menuNodes = [ aiChatNode, imageUnderstandNode, @@ -332,7 +419,8 @@ export const menuNodes = [ documentExtractNode, speechToTextNode, textToSpeechNode, - variableAssignNode + variableAssignNode, + loopNode ] /** @@ -426,7 +514,9 @@ export const nodeDict: any = { [WorkflowType.TextToSpeechNode]: textToSpeechNode, [WorkflowType.SpeechToTextNode]: speechToTextNode, [WorkflowType.ImageGenerateNode]: imageGenerateNode, - [WorkflowType.VariableAssignNode]: variableAssignNode + [WorkflowType.VariableAssignNode]: variableAssignNode, + [WorkflowType.LoopNode]: loopNode, + [WorkflowType.LoopBodyNode]: loopBodyNode } export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/common/loopEdge.ts b/ui/src/workflow/common/loopEdge.ts new file mode 100644 index 00000000000..8f53c8dc43f --- /dev/null +++ b/ui/src/workflow/common/loopEdge.ts @@ -0,0 +1,73 @@ +import { BezierEdge, BezierEdgeModel, h } from '@logicflow/core' + +class CustomEdgeModel2 extends BezierEdgeModel { + getArrowStyle() { + const arrowStyle = super.getArrowStyle() + arrowStyle.offset = 0 + arrowStyle.verticalLength = 0 + return arrowStyle + } + + getEdgeStyle() { + const style = super.getEdgeStyle() + // svg属性 + style.strokeWidth = 2 + style.stroke = '#BBBFC4' + style.offset = 0 + return style + } + /** + * 重写此方法,使保存数据是能带上锚点数据。 + */ + getData() { + const data: any = super.getData() + if (data) { + data.sourceAnchorId = this.sourceAnchorId + data.targetAnchorId = this.targetAnchorId + } + return data + } + /** + * 给边自定义方案,使其支持基于锚点的位置更新边的路径 + */ + updatePathByAnchor() { + // TODO + const sourceNodeModel = this.graphModel.getNodeModelById(this.sourceNodeId) + const sourceAnchor = sourceNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.sourceAnchorId) + + const targetNodeModel = this.graphModel.getNodeModelById(this.targetNodeId) + const targetAnchor = targetNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.targetAnchorId) + if (sourceAnchor && targetAnchor) { + const startPoint = { + x: sourceAnchor.x, + y: sourceAnchor.y - 10 + } + this.updateStartPoint(startPoint) + const endPoint = { + x: targetAnchor.x, + y: targetAnchor.y + 3 + } + + this.updateEndPoint(endPoint) + } + + // 这里需要将原有的pointsList设置为空,才能触发bezier的自动计算control点。 + this.pointsList = [] + this.initPoints() + } + setAttributes(): void { + super.setAttributes() + this.isHitable = true + this.zIndex = 0 + } +} + +export default { + type: 'loop-edge', + view: BezierEdge, + model: CustomEdgeModel2 +} diff --git a/ui/src/workflow/common/shortcut.ts b/ui/src/workflow/common/shortcut.ts index 2b3235b6bea..e437a2f3f20 100644 --- a/ui/src/workflow/common/shortcut.ts +++ b/ui/src/workflow/common/shortcut.ts @@ -91,12 +91,20 @@ export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { return } if (elements.edges.length > 0 && elements.nodes.length == 0) { - elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) + elements.edges.forEach((edge: any) => { + if (edge.type === 'app-edge') { + lf.deleteEdge(edge.id) + } + }) return } - const nodes = elements.nodes.filter((node) => ['start-node', 'base-node'].includes(node.type)) + const nodes = elements.nodes.filter((node) => + ['start-node', 'base-node', 'loop-body-node'].includes(node.type) + ) if (nodes.length > 0) { - MsgError(`${nodes[0].properties?.stepName}${t('views.applicationWorkflow.delete.deleteMessage')}`) + MsgError( + `${nodes[0].properties?.stepName}${t('views.applicationWorkflow.delete.deleteMessage')}` + ) return } MsgConfirm(t('common.tip'), t('views.applicationWorkflow.delete.confirmTitle'), { @@ -107,7 +115,17 @@ export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { if (graph.textEditElement) return true elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) - elements.nodes.forEach((node: any) => lf.deleteNode(node.id)) + elements.nodes.forEach((node: any) => { + if (node.type === 'loop-node') { + const next = lf.getNodeOutgoingNode(node.id) + next.forEach((n: any) => { + if (n.type === 'loop-body-node') { + lf.deleteNode(n.id) + } + }) + } + lf.deleteNode(node.id) + }) }) return false diff --git a/ui/src/workflow/common/teleport.ts b/ui/src/workflow/common/teleport.ts index 6f88b414fa8..409f2b7fbe1 100644 --- a/ui/src/workflow/common/teleport.ts +++ b/ui/src/workflow/common/teleport.ts @@ -65,9 +65,9 @@ export function getTeleport(): any { // 比对当前界面显示的flowId,只更新items[当前页面flowId:nodeId]的数据 // 比如items[0]属于Page1的数据,那么Page2无论active=true/false,都无法执行items[0] - if (id.startsWith(props.flowId)) { - children.push(items[id]) - } + // if (id.startsWith(props.flowId)) { + children.push(items[id]) + // } }) return h( Fragment, diff --git a/ui/src/workflow/common/validate.ts b/ui/src/workflow/common/validate.ts index cc9378bc317..a131f1114e8 100644 --- a/ui/src/workflow/common/validate.ts +++ b/ui/src/workflow/common/validate.ts @@ -10,7 +10,8 @@ const end_nodes: Array = [ WorkflowType.Application, WorkflowType.SpeechToTextNode, WorkflowType.TextToSpeechNode, - WorkflowType.ImageGenerateNode, + WorkflowType.ImageGenerateNode, + WorkflowType.LoopBodyNode ] export class WorkFlowInstance { nodes diff --git a/ui/src/workflow/index.vue b/ui/src/workflow/index.vue index 03361bcec3f..e4e1c6f7ada 100644 --- a/ui/src/workflow/index.vue +++ b/ui/src/workflow/index.vue @@ -8,6 +8,7 @@ import LogicFlow from '@logicflow/core' import { ref, onMounted, computed } from 'vue' import AppEdge from './common/edge' +import loopEdge from './common/loopEdge' import Control from './common/NodeControl.vue' import { baseNodes } from '@/workflow/common/data' import '@logicflow/extension/lib/style/index.css' @@ -93,7 +94,11 @@ const renderGraphData = (data?: any) => { flowId.value = lf.value.graphModel.flowId }) initDefaultShortcut(lf.value, lf.value.graphModel) - lf.value.batchRegister([...Object.keys(nodes).map((key) => nodes[key].default), AppEdge]) + lf.value.batchRegister([ + ...Object.keys(nodes).map((key) => nodes[key].default), + AppEdge, + loopEdge + ]) lf.value.setDefaultEdgeType('app-edge') lf.value.render(data ? data : {}) @@ -117,7 +122,18 @@ const validate = () => { return Promise.all(lf.value.graphModel.nodes.map((element: any) => element?.validate?.())) } const getGraphData = () => { - return lf.value.getGraphData() + const graph_data = lf.value.getGraphData() + graph_data.nodes = graph_data.nodes.filter((node: any) => { + if (node.type === 'loop-body-node') { + const node_model = lf.value.getNodeModelById(node.id) + console.log(node_model) + node_model.set_loop_body() + return false + } + return true + }) + graph_data.edges = graph_data.edges.filter((node: any) => node.type !== 'loop-edge') + return graph_data } const onmousedown = (shapeItem: ShapeItem) => { diff --git a/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue b/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue new file mode 100644 index 00000000000..66323f03771 --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/LoopBodyContainer.vue @@ -0,0 +1,222 @@ +
+
+
+
+
+ +

{{ nodeModel.properties.stepName }}

+
+
+ +
+ + + +
+
+
+
+ + + + + + + + + +
+ + + diff --git a/ui/src/workflow/nodes/loop-body-node/index.ts b/ui/src/workflow/nodes/loop-body-node/index.ts new file mode 100644 index 00000000000..df3d84e8d12 --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/index.ts @@ -0,0 +1,39 @@ +import LoopNode from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +import { WorkflowType } from '@/enums/workflow' +class LoopBodyNodeView extends AppNode { + constructor(props: any) { + super(props, LoopNode) + } +} +class LoopBodyModel extends AppNodeModel { + refreshBranch() { + // 更新节点连接边的path + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.outgoing.edges.forEach((edge: any) => { + edge.updatePathByAnchor() + }) + } + getDefaultAnchor() { + const { id, x, y, width, height } = this + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + anchors.push({ + edgeAddable: false, + x: x, + y: y - height / 2 + 10, + id: `${id}_children`, + type: 'children' + }) + + return anchors + } +} +export default { + type: 'loop-body-node', + model: LoopBodyModel, + view: LoopBodyNodeView +} diff --git a/ui/src/workflow/nodes/loop-body-node/index.vue b/ui/src/workflow/nodes/loop-body-node/index.vue new file mode 100644 index 00000000000..99dbb0f8b9b --- /dev/null +++ b/ui/src/workflow/nodes/loop-body-node/index.vue @@ -0,0 +1,97 @@ + + + diff --git a/ui/src/workflow/nodes/loop-node/index.ts b/ui/src/workflow/nodes/loop-node/index.ts new file mode 100644 index 00000000000..ac7467d5d83 --- /dev/null +++ b/ui/src/workflow/nodes/loop-node/index.ts @@ -0,0 +1,56 @@ +import LoopNode from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +import { WorkflowType } from '@/enums/workflow' +class LoopNodeView extends AppNode { + constructor(props: any) { + super(props, LoopNode) + } +} +class LoopModel extends AppNodeModel { + refreshBranch() { + // 更新节点连接边的path + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.outgoing.edges.forEach((edge: any) => { + edge.updatePathByAnchor() + }) + } + getDefaultAnchor() { + const { id, x, y, width, height } = this + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + + if (this.type !== WorkflowType.Base) { + if (this.type !== WorkflowType.Start) { + anchors.push({ + x: x - width / 2 + 10, + y: showNode ? y : y - 15, + id: `${id}_left`, + edgeAddable: false, + type: 'left' + }) + } + anchors.push({ + x: x + width / 2 - 10, + y: showNode ? y : y - 15, + id: `${id}_right`, + type: 'right' + }) + } + anchors.push({ + x: x, + y: y + height / 2 - 25, + id: `${id}_children`, + type: 'children' + }) + + return anchors + } +} +export default { + type: 'loop-node', + model: LoopModel, + view: LoopNodeView +} diff --git a/ui/src/workflow/nodes/loop-node/index.vue b/ui/src/workflow/nodes/loop-node/index.vue new file mode 100644 index 00000000000..16ffdabf33a --- /dev/null +++ b/ui/src/workflow/nodes/loop-node/index.vue @@ -0,0 +1,158 @@ + + + diff --git a/ui/src/workflow/nodes/start-node/index.vue b/ui/src/workflow/nodes/start-node/index.vue index 9414f9251f5..7661f854c11 100644 --- a/ui/src/workflow/nodes/start-node/index.vue +++ b/ui/src/workflow/nodes/start-node/index.vue @@ -1,6 +1,8 @@