diff --git a/src/Tree.tsx b/src/Tree.tsx index 690566e9..febf6ffd 100644 --- a/src/Tree.tsx +++ b/src/Tree.tsx @@ -76,7 +76,7 @@ export type AllowDrop = ( export type DraggableFn = (node: DataNode) => boolean; export type DraggableConfig = { - icon?: React.ReactNode | false; + icon?: ((node: DataNode) => React.ReactNode) | React.ReactNode | false; nodeDraggable?: DraggableFn; }; diff --git a/src/TreeNode.tsx b/src/TreeNode.tsx index 3a51f2cb..5162e312 100644 --- a/src/TreeNode.tsx +++ b/src/TreeNode.tsx @@ -4,7 +4,7 @@ import * as React from 'react'; // @ts-ignore import { TreeContext, TreeContextProps } from './contextTypes'; import Indent from './Indent'; -import { TreeNodeProps } from './interface'; +import { DataNode, TreeNodeProps } from './interface'; import getEntity from './utils/keyUtil'; import { convertNodePropsToEventData } from './utils/treeUtil'; @@ -293,13 +293,15 @@ class InternalTreeNode extends React.Component { + renderDragHandler = (data: DataNode) => { const { context: { draggable, prefixCls }, } = this.props; return draggable?.icon ? ( - {draggable.icon} + + {typeof draggable.icon === 'function' ? draggable.icon(data) : draggable.icon} + ) : null; }; @@ -579,13 +581,8 @@ class InternalTreeNode extends React.Component - - {this.renderDragHandler()} + + {this.renderDragHandler(data)} {this.renderSwitcher()} {this.renderCheckbox()} {this.renderSelector()} diff --git a/tests/TreeDraggable.spec.tsx b/tests/TreeDraggable.spec.tsx index 2d5737cf..97f61fa0 100644 --- a/tests/TreeDraggable.spec.tsx +++ b/tests/TreeDraggable.spec.tsx @@ -1086,6 +1086,35 @@ describe('Tree Draggable', () => { expect(container.querySelectorAll('.handler')).toHaveLength(2); }); + it('render handler icon render function', () => { + const { container } = render( + !node?.children && , + }} + defaultExpandAll + treeData={[ + { + title: 'Parent', + key: 'parent', + children: [ + { + title: 'Child', + key: 'child', + }, + { + title: 'Child1', + key: 'child1', + }, + ], + }, + ]} + />, + ); + + expect(container.querySelectorAll('.handler')).toHaveLength(2); + }); + it('not break with fieldNames', () => { const onDrop = jest.fn();