From a8a8e2fe710ab23e4081732d71fb36ad90ee9be1 Mon Sep 17 00:00:00 2001 From: Faris Durrani Date: Thu, 7 Sep 2023 21:46:56 -0400 Subject: [PATCH] Default layers for Iris --- .../Tabular/components/TabularDnd.tsx | 96 ++++++++++++++----- .../Tabular/constants/tabularConstants.ts | 31 ++---- 2 files changed, 81 insertions(+), 46 deletions(-) diff --git a/frontend/src/features/Train/features/Tabular/components/TabularDnd.tsx b/frontend/src/features/Train/features/Tabular/components/TabularDnd.tsx index b818d1ebd..bdd63c281 100644 --- a/frontend/src/features/Train/features/Tabular/components/TabularDnd.tsx +++ b/frontend/src/features/Train/features/Tabular/components/TabularDnd.tsx @@ -9,7 +9,7 @@ import { TextField, Typography, } from "@mui/material"; -import React, { useCallback, useEffect } from "react"; +import React, { useEffect } from "react"; import ReactFlow, { Background, BackgroundVariant, @@ -24,10 +24,15 @@ import ReactFlow, { useNodesState, } from "reactflow"; import "reactflow/dist/style.css"; -import { ALL_LAYERS, STEP_SETTINGS } from "../constants/tabularConstants"; +import { + ALL_LAYERS, + DEFAULT_LAYERS, + STEP_SETTINGS, +} from "../constants/tabularConstants"; import { ParameterData } from "../types/tabularTypes"; import assert from "assert"; -import { randomUUID } from "crypto"; +import { nanoid } from "nanoid/non-secure"; +import { toast } from "react-toastify"; interface TabularDndProps { setLayers?: (layers: ParameterData["layers"]) => void; @@ -35,6 +40,26 @@ interface TabularDndProps { export default function TabularDnd(props: TabularDndProps) { const { setLayers } = props; + + const initialNodes: Node[] = [ + ROOT_NODE, + ...DEFAULT_LAYERS.IRIS.map((layer, i) => ({ + id: `${layer.value}-${i}`, + type: "textUpdater", + position: { + x: DEFAULT_X_POSITION, + y: (i + 1) * 125, + }, + data: { + label: STEP_SETTINGS.PARAMETERS.layers[layer.value].label, + value: layer.value, + parameters: layer.parameters.slice(), + onChange: onChange, + }, + })), + ]; + const initialEdges: Edge[] = createInitialEdges(); + const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes); const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges); @@ -47,6 +72,7 @@ export default function TabularDnd(props: TabularDndProps) { function handleExportLayers(): ParameterData["layers"] { const layers: ParameterData["layers"] = []; + const visited = new Set(["root"]); const directedEdges: Record> = {}; edges.forEach((edge) => { @@ -62,6 +88,12 @@ export default function TabularDnd(props: TabularDndProps) { // appending all layers from the graph while (nextNode) { assert(nextNode.data.value !== "root"); + + if (visited.has(nextNode.id)) { + toast.error("Cycle detected in layers"); + return layers; + } + visited.add(nextNode.id); layers.push({ value: nextNode.data.value, parameters: nextNode.data.parameters || [], @@ -72,7 +104,7 @@ export default function TabularDnd(props: TabularDndProps) { return layers; } - const onChange = useCallback((args: OnChangeArgs) => { + function onChange(args: OnChangeArgs) { setNodes((nds) => nds.map((node) => { if (node.id !== args.id) return node; @@ -82,7 +114,7 @@ export default function TabularDnd(props: TabularDndProps) { return node; }) ); - }, []); + } return ( <> @@ -106,10 +138,10 @@ export default function TabularDnd(props: TabularDndProps) { setNodes((cur) => [ ...cur, { - id: `${value}-${randomUUID()}`, + id: `${value}-${nanoid()}`, type: "textUpdater", position: { - x: Math.random() * 50, + x: DEFAULT_X_POSITION, y: Math.random() * 50, }, data: { @@ -122,7 +154,6 @@ export default function TabularDnd(props: TabularDndProps) { }, }, ]); - console.log(nodes); }} > {value} @@ -243,19 +274,40 @@ interface LayerNodeData { onChange: (args: OnChangeArgs) => void; } -const initialNodes: Node[] = [ - { - id: `root`, - type: "input", - position: { x: 0, y: 0 }, - deletable: false, - data: { - label: "Start", - value: "root", - parameters: [], - onChange: () => undefined, - }, +const DEFAULT_X_POSITION = 10; + +const ROOT_NODE: Node = { + id: `root`, + type: "input", + position: { x: DEFAULT_X_POSITION, y: 0 }, + deletable: false, + data: { + label: "Start", + value: "root", + parameters: [], + onChange: () => undefined, }, -]; -const initialEdges: Edge[] = [{ id: "e1-2", source: "1", target: "2" }]; +}; const nodeTypes: NodeTypes = { textUpdater: TextUpdaterNode }; + +function createInitialEdges(): Edge[] { + const edges: Edge[] = []; + const defaultLayers = DEFAULT_LAYERS.IRIS; + + // connecting root to first layer + edges.push({ + id: `eroot-0`, + source: "root", + target: `${defaultLayers[0].value}-0`, + }); + + // connecting all layers + for (let i = 0; i < defaultLayers.length - 1; i++) { + edges.push({ + id: `e${i}-${i + 1}`, + source: `${defaultLayers[i].value}-${i}`, + target: `${defaultLayers[i + 1].value}-${i + 1}`, + }); + } + return edges; +} diff --git a/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts b/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts index 649309dcc..533b4fa53 100644 --- a/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts +++ b/frontend/src/features/Train/features/Tabular/constants/tabularConstants.ts @@ -144,32 +144,15 @@ export const STEP_SETTINGS = { }, } as const; -type DEFAULT_DATASET_VALUES = - (typeof STEP_SETTINGS)["DATASET"]["defaultDatasets"][number]["value"]; export type ALL_LAYERS = keyof typeof STEP_SETTINGS.PARAMETERS.layers; -export const DEFAULT_LAYERS: Partial<{ - [dataset in DEFAULT_DATASET_VALUES]: { - value: ALL_LAYERS; - parameters: number[]; - }[]; -}> = { +export const DEFAULT_LAYERS: { + IRIS: { value: ALL_LAYERS; parameters: number[] }[]; +} = { IRIS: [ - { - value: "LINEAR", - parameters: [4, 10], - }, - { - value: "RELU", - parameters: [], - }, - { - value: "LINEAR", - parameters: [10, 3], - }, - { - value: "SOFTMAX", - parameters: [-1], - }, + { value: "LINEAR", parameters: [4, 10] }, + { value: "RELU", parameters: [] }, + { value: "LINEAR", parameters: [10, 3] }, + { value: "SOFTMAX", parameters: [-1] }, ], };