Skip to content

Commit

Permalink
Default layers for Iris
Browse files Browse the repository at this point in the history
  • Loading branch information
farisdurrani committed Sep 8, 2023
1 parent 2a99dd5 commit a8a8e2f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
TextField,
Typography,
} from "@mui/material";
import React, { useCallback, useEffect } from "react";
import React, { useEffect } from "react";
import ReactFlow, {
Background,
BackgroundVariant,
Expand All @@ -24,17 +24,42 @@ import ReactFlow, {
useNodesState,
} from "reactflow";
import "reactflow/dist/style.css";
import { ALL_LAYERS, STEP_SETTINGS } from "../constants/tabularConstants";
import {
ALL_LAYERS,
DEFAULT_LAYERS,
STEP_SETTINGS,
} from "../constants/tabularConstants";
import { ParameterData } from "../types/tabularTypes";
import assert from "assert";
import { randomUUID } from "crypto";
import { nanoid } from "nanoid/non-secure";
import { toast } from "react-toastify";

interface TabularDndProps {
setLayers?: (layers: ParameterData["layers"]) => void;
}

export default function TabularDnd(props: TabularDndProps) {
const { setLayers } = props;

const initialNodes: Node<LayerNodeData>[] = [
ROOT_NODE,
...DEFAULT_LAYERS.IRIS.map((layer, i) => ({
id: `${layer.value}-${i}`,
type: "textUpdater",
position: {
x: DEFAULT_X_POSITION,
y: (i + 1) * 125,
},
data: {
label: STEP_SETTINGS.PARAMETERS.layers[layer.value].label,
value: layer.value,
parameters: layer.parameters.slice(),
onChange: onChange,
},
})),
];
const initialEdges: Edge[] = createInitialEdges();

const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);

Expand All @@ -47,6 +72,7 @@ export default function TabularDnd(props: TabularDndProps) {

function handleExportLayers(): ParameterData["layers"] {
const layers: ParameterData["layers"] = [];
const visited = new Set<string>(["root"]);

const directedEdges: Record<string, Node<LayerNodeData>> = {};
edges.forEach((edge) => {
Expand All @@ -62,6 +88,12 @@ export default function TabularDnd(props: TabularDndProps) {
// appending all layers from the graph
while (nextNode) {
assert(nextNode.data.value !== "root");

if (visited.has(nextNode.id)) {
toast.error("Cycle detected in layers");
return layers;
}
visited.add(nextNode.id);
layers.push({
value: nextNode.data.value,
parameters: nextNode.data.parameters || [],
Expand All @@ -72,7 +104,7 @@ export default function TabularDnd(props: TabularDndProps) {
return layers;
}

const onChange = useCallback((args: OnChangeArgs) => {
function onChange(args: OnChangeArgs) {
setNodes((nds) =>
nds.map((node) => {
if (node.id !== args.id) return node;
Expand All @@ -82,7 +114,7 @@ export default function TabularDnd(props: TabularDndProps) {
return node;
})
);
}, []);
}

return (
<>
Expand All @@ -106,10 +138,10 @@ export default function TabularDnd(props: TabularDndProps) {
setNodes((cur) => [
...cur,
{
id: `${value}-${randomUUID()}`,
id: `${value}-${nanoid()}`,
type: "textUpdater",
position: {
x: Math.random() * 50,
x: DEFAULT_X_POSITION,
y: Math.random() * 50,
},
data: {
Expand All @@ -122,7 +154,6 @@ export default function TabularDnd(props: TabularDndProps) {
},
},
]);
console.log(nodes);
}}
>
{value}
Expand Down Expand Up @@ -243,19 +274,40 @@ interface LayerNodeData {
onChange: (args: OnChangeArgs) => void;
}

const initialNodes: Node<LayerNodeData>[] = [
{
id: `root`,
type: "input",
position: { x: 0, y: 0 },
deletable: false,
data: {
label: "Start",
value: "root",
parameters: [],
onChange: () => undefined,
},
const DEFAULT_X_POSITION = 10;

const ROOT_NODE: Node<LayerNodeData> = {
id: `root`,
type: "input",
position: { x: DEFAULT_X_POSITION, y: 0 },
deletable: false,
data: {
label: "Start",
value: "root",
parameters: [],
onChange: () => undefined,
},
];
const initialEdges: Edge[] = [{ id: "e1-2", source: "1", target: "2" }];
};
const nodeTypes: NodeTypes = { textUpdater: TextUpdaterNode };

function createInitialEdges(): Edge[] {
const edges: Edge[] = [];
const defaultLayers = DEFAULT_LAYERS.IRIS;

// connecting root to first layer
edges.push({
id: `eroot-0`,
source: "root",
target: `${defaultLayers[0].value}-0`,
});

// connecting all layers
for (let i = 0; i < defaultLayers.length - 1; i++) {
edges.push({
id: `e${i}-${i + 1}`,
source: `${defaultLayers[i].value}-${i}`,
target: `${defaultLayers[i + 1].value}-${i + 1}`,
});
}
return edges;
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,15 @@ export const STEP_SETTINGS = {
},
} as const;

type DEFAULT_DATASET_VALUES =
(typeof STEP_SETTINGS)["DATASET"]["defaultDatasets"][number]["value"];
export type ALL_LAYERS = keyof typeof STEP_SETTINGS.PARAMETERS.layers;

export const DEFAULT_LAYERS: Partial<{
[dataset in DEFAULT_DATASET_VALUES]: {
value: ALL_LAYERS;
parameters: number[];
}[];
}> = {
export const DEFAULT_LAYERS: {
IRIS: { value: ALL_LAYERS; parameters: number[] }[];
} = {
IRIS: [
{
value: "LINEAR",
parameters: [4, 10],
},
{
value: "RELU",
parameters: [],
},
{
value: "LINEAR",
parameters: [10, 3],
},
{
value: "SOFTMAX",
parameters: [-1],
},
{ value: "LINEAR", parameters: [4, 10] },
{ value: "RELU", parameters: [] },
{ value: "LINEAR", parameters: [10, 3] },
{ value: "SOFTMAX", parameters: [-1] },
],
};

0 comments on commit a8a8e2f

Please sign in to comment.