diff --git a/src/components/graph/GraphCanvas.vue b/src/components/graph/GraphCanvas.vue index 731e132be..6033a5655 100644 --- a/src/components/graph/GraphCanvas.vue +++ b/src/components/graph/GraphCanvas.vue @@ -46,7 +46,10 @@ import type { RenderedTreeExplorerNode } from '@/types/treeExplorerTypes' import { useNodeBookmarkStore } from '@/stores/nodeBookmarkStore' import { useCanvasStore } from '@/stores/graphStore' import { ComfyModelDef } from '@/stores/modelStore' -import { useModelToNodeStore } from '@/stores/modelToNodeStore' +import { + ModelNodeProvider, + useModelToNodeStore +} from '@/stores/modelToNodeStore' import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue' const emit = defineEmits(['ready']) @@ -143,15 +146,33 @@ onMounted(async () => { comfyApp.addNodeOnGraph(nodeDef, { pos }) } else if (node.data instanceof ComfyModelDef) { const model = node.data - const provider = modelToNodeStore.getNodeProvider(model.directory) - if (provider) { - const pos = comfyApp.clientPosToCanvasPos([ - loc.clientX - 20, - loc.clientY - ]) - const node = comfyApp.addNodeOnGraph(provider.nodeDef, { pos }) - const widget = node.widgets.find( - (widget) => widget.name === provider.key + const pos = comfyApp.clientPosToCanvasPos([loc.clientX, loc.clientY]) + const nodeAtPos = comfyApp.graph.getNodeOnPos(pos[0], pos[1]) + let targetProvider: ModelNodeProvider | null = null + let targetGraphNode: LGraphNode | null = null + if (nodeAtPos) { + const providers = modelToNodeStore.getAllNodeProviders( + model.directory + ) + for (const provider of providers) { + if (provider.nodeDef.name === nodeAtPos.comfyClass) { + targetGraphNode = nodeAtPos + targetProvider = provider + } + } + } + if (!targetGraphNode) { + const provider = modelToNodeStore.getNodeProvider(model.directory) + if (provider) { + targetGraphNode = comfyApp.addNodeOnGraph(provider.nodeDef, { + pos + }) + targetProvider = provider + } + } + if (targetGraphNode) { + const widget = targetGraphNode.widgets.find( + (widget) => widget.name === targetProvider.key ) if (widget) { widget.value = model.name diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index 038d5ec6a..411e73702 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -1,7 +1,6 @@ import { ComfyNodeDefImpl } from '@/stores/nodeDefStore' import { useNodeDefStore } from '@/stores/nodeDefStore' import { defineStore } from 'pinia' -import { toRaw } from 'vue' /** Helper class that defines how to construct a node from a model. */ export class ModelNodeProvider { @@ -20,7 +19,7 @@ export class ModelNodeProvider { /** Service for mapping model types (by folder name) to nodes. */ export const useModelToNodeStore = defineStore('modelToNode', { state: () => ({ - modelToNodeMap: {} as Record, + modelToNodeMap: {} as Record, nodeDefStore: useNodeDefStore(), haveDefaultsLoaded: false }), @@ -31,6 +30,16 @@ export const useModelToNodeStore = defineStore('modelToNode', { * @returns The node provider for the given model type name. */ getNodeProvider(modelType: string): ModelNodeProvider { + this.registerDefaults() + return this.modelToNodeMap[modelType]?.[0] + }, + + /** + * Get the list of all valid node providers for the given model type name. + * @param modelType The name of the model type to get the node providers for. + * @returns The list of all valid node providers for the given model type name. + */ + getAllNodeProviders(modelType: string): ModelNodeProvider[] { this.registerDefaults() return this.modelToNodeMap[modelType] }, @@ -42,7 +51,21 @@ export const useModelToNodeStore = defineStore('modelToNode', { */ registerNodeProvider(modelType: string, nodeProvider: ModelNodeProvider) { this.registerDefaults() - this.modelToNodeMap[modelType] = nodeProvider + this.modelToNodeMap[modelType] ??= [] + this.modelToNodeMap[modelType].push(nodeProvider) + }, + + /** + * Register a node provider for the given simple names. + * @param modelType The name of the model type to register the node provider for. + * @param nodeClass The node class name to register. + * @param key The key to use for the node input. + */ + quickRegister(modelType: string, nodeClass: string, key: string) { + this.registerNodeProvider( + modelType, + new ModelNodeProvider(this.nodeDefStore.nodeDefsByName[nodeClass], key) + ) }, registerDefaults() { @@ -53,34 +76,16 @@ export const useModelToNodeStore = defineStore('modelToNode', { return } this.haveDefaultsLoaded = true - this.registerNodeProvider( + this.quickRegister('checkpoints', 'CheckpointLoaderSimple', 'ckpt_name') + this.quickRegister( 'checkpoints', - new ModelNodeProvider( - this.nodeDefStore.nodeDefsByName['CheckpointLoaderSimple'], - 'ckpt_name' - ) - ) - this.registerNodeProvider( - 'loras', - new ModelNodeProvider( - this.nodeDefStore.nodeDefsByName['LoraLoader'], - 'lora_name' - ) - ) - this.registerNodeProvider( - 'vae', - new ModelNodeProvider( - this.nodeDefStore.nodeDefsByName['VAELoader'], - 'vae_name' - ) - ) - this.registerNodeProvider( - 'controlnet', - new ModelNodeProvider( - this.nodeDefStore.nodeDefsByName['ControlNetLoader'], - 'control_net_name' - ) + 'ImageOnlyCheckpointLoader', + 'ckpt_name' ) + this.quickRegister('loras', 'LoraLoader', 'lora_name') + this.quickRegister('loras', 'LoraLoaderModelOnly', 'lora_name') + this.quickRegister('vae', 'VAELoader', 'vae_name') + this.quickRegister('controlnet', 'ControlNetLoader', 'control_net_name') } } })