
An MLX port of FLUX based on the Huggingface Diffusers implementation.


About

Run the powerful FLUX models from Black Forest Labs locally on your Mac!


Philosophy

MFLUX is a line-by-line port of the FLUX implementation in the Huggingface Diffusers library to Apple MLX. MFLUX is purposefully kept minimal and explicit: network architectures are hardcoded and no config files are used except for the tokenizers. The aim is to have a tiny codebase with the single purpose of expressing these models (thereby avoiding too many abstractions). While MFLUX prioritizes readability over generality and performance, it can still be quite fast, and even faster when quantized.

All models are implemented from scratch in MLX and only the tokenizers are used via the Huggingface Transformers library. Other than that, there are only minimal dependencies like Numpy and Pillow for simple image post-processing.

💿 Installation

For users, the easiest way to install MFLUX is with the uv tool. If you have uv installed, simply run:

uv tool install --upgrade mflux

to get the mflux-generate and related command line executables. You can skip to the usage guides below.
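To verify the installation, you can print the built-in help, which should list the command-line arguments described further below:

mflux-generate --help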

For Python 3.13 dev preview

The T5 encoder depends on sentencepiece, which does not have an installable wheel artifact for Python 3.13 as of Nov 2024. Until Google publishes a 3.13 wheel, you need to build your own wheel with the official build instructions or, for convenience, use a .whl pre-built by contributor @anthonywu. The steps below should work for most developers, though your system may vary.

uv venv --python 3.13
python -V  # e.g. Python 3.13.0rc2
source .venv/bin/activate

# for your convenience, you can use the contributor wheel
uv pip install https://github.com/anthonywu/sentencepiece/releases/download/0.2.1-py13dev/sentencepiece-0.2.1-cp313-cp313-macosx_11_0_arm64.whl

# install MFLUX in editable mode using the PyTorch nightly index
uv pip install --pre --extra-index-url https://download.pytorch.org/whl/nightly -e .
Alternatively, to create a user virtual environment the classic way:
mkdir -p mflux && cd mflux && python3 -m venv .venv && source .venv/bin/activate

This creates and activates a virtual environment in the mflux folder. After that, install MFLUX via pip:

pip install -U mflux
For contributors (click to expand)
  1. Clone the repo:
 git clone git@github.com:filipstrand/mflux.git
  2. Install the application:
 make install
  3. Run the test suite:
 make test
  4. Follow format and lint checks prior to submitting Pull Requests. The recommended make lint and make format targets install and use ruff. You can set up your editor/IDE to lint/format automatically, or use our provided make helpers:
  • make format - formats your code
  • make lint - shows lint errors and warnings, but does not auto-fix
  • make check - via pre-commit hooks, formats your code and attempts to auto-fix lint errors
  • consult the official ruff documentation for advanced usage

If you have trouble installing MFLUX, please see the installation related issues section.

🖼️ Generating an image

Run the command mflux-generate, specifying a prompt, the model, and some optional arguments. For example, here we use a quantized version of the schnell model for 2 steps:

mflux-generate --model schnell --prompt "Luxury food photograph" --steps 2 --seed 2 -q 8

This example uses the more powerful dev model with 25 time steps:

mflux-generate --model dev --prompt "Luxury food photograph" --steps 25 --seed 2 -q 8

⚠️ If the specified model is not already downloaded on your machine, it will start the download process and fetch the model weights (~34GB each for the Schnell and Dev models). See the quantization section for running compressed versions of the model. ⚠️

By default, model files are downloaded to the .cache folder within your home directory. For example, in my setup, the path looks like this:

/Users/filipstrand/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-dev

To change this default behavior, set the HF_HOME environment variable. For more details on how to adjust this setting, please refer to the Hugging Face documentation.
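For example, to keep the model files on an external drive instead (the path below is just an illustration):

export HF_HOME=/Volumes/ExternalDrive/huggingface
mflux-generate --model schnell --prompt "Luxury food photograph" --steps 2 --seed 2 -q 8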

🔒 FLUX.1-dev currently requires granted access to its Huggingface repo. For troubleshooting, see the issue tracker 🔒

📜 Full list of Command-Line Arguments

  • --prompt (required, str): Text description of the image to generate.

  • --model or -m (required, str): Model to use for generation ("schnell" or "dev").

  • --output (optional, str, default: "image.png"): Output image filename.

  • --seed (optional, int, default: None): Seed for random number generation. Default is time-based.

  • --height (optional, int, default: 1024): Height of the output image in pixels.

  • --width (optional, int, default: 1024): Width of the output image in pixels.

  • --steps (optional, int, default: 4): Number of inference steps.

  • --guidance (optional, float, default: 3.5): Guidance scale (only used for "dev" model).

  • --path (optional, str, default: None): Path to a local model on disk.

  • --quantize or -q (optional, int, default: None): Quantization (choose between 4 or 8).

  • --lora-paths (optional, [str], default: None): The paths to the LoRA weights.

  • --lora-scales (optional, [float], default: None): The scale for each respective LoRA (will default to 1.0 if not specified and only one LoRA weight is loaded.)

  • --metadata (optional): Exports a .json file containing the metadata for the image with the same name. (Even without this flag, the image metadata is saved and can be viewed using exiftool image.png)

  • --controlnet-image-path (required, str): Path to the local image used by ControlNet to guide output generation.

  • --controlnet-strength (optional, float, default: 0.4): Degree of influence the control image has on the output. Ranges from 0.0 (no influence) to 1.0 (full influence).

  • --controlnet-save-canny (optional, bool, default: False): If set, saves the Canny edge detection reference image used by ControlNet.

  • --init-image-path (optional, str, default: None): Local path to the initial image for image-to-image generation.

  • --init-image-strength (optional, float, default: 0.4): Controls how strongly the initial image influences the output image. A value of 0.0 means no influence.

  • --config-from-metadata or -C (optional, str): [EXPERIMENTAL] Path to a prior file saved via --metadata, or a compatible handcrafted config file adhering to the expected args schema.

parameters supported by config files

How configs are used

  • all config properties are optional and applied to the image generation if applicable
  • invalid or incompatible properties will be ignored

Config schema

{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "type": "object",
  "properties": {
    "seed": {
      "type": ["integer", "null"]
    },
    "steps": {
      "type": ["integer", "null"]
    },
    "guidance": {
      "type": ["number", "null"]
    },
    "quantize": {
      "type": ["null", "string"]
    },
    "lora_paths": {
      "type": ["array", "null"],
      "items": {
        "type": "string"
      }
    },
    "lora_scales": {
      "type": ["array", "null"],
      "items": {
        "type": "number"
      }
    },
    "prompt": {
      "type": ["string", "null"]
    }
  }
}

Example

{
  "model": "dev",
  "seed": 42,
  "steps": 8,
  "guidance": 3.0,
  "quantize": 4,
  "lora_paths": [
    "/some/path1/to/subject.safetensors",
    "/some/path2/to/style.safetensors"
  ],
  "lora_scales": [
    0.8,
    0.4
  ],
  "prompt": "award winning modern art, MOMA"
}
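Assuming the example above is saved as config.json (the filename is arbitrary), a possible invocation looks like the following; --model and --prompt are passed explicitly as well, since they are listed as required arguments:

mflux-generate -C config.json --model dev --prompt "award winning modern art, MOMA"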

Or, with the correct python environment active, create and run a separate script like the following:

from mflux import Flux1, Config

# Load the model
flux = Flux1.from_alias(
   alias="schnell",  # "schnell" or "dev"
   quantize=8,       # 4 or 8
)

# Generate an image
image = flux.generate_image(
   seed=2,
   prompt="Luxury food photograph",
   config=Config(
      num_inference_steps=2,  # "schnell" works well with 2-4 steps, "dev" works well with 20-25 steps
      height=1024,
      width=1024,
   )
)

image.save(path="image.png")

For more options on how to configure MFLUX, please see generate.py.

⏱️ Image generation speed (updated)

These numbers are based on the non-quantized schnell model, with the configuration provided in the code snippet below. To time your machine, run the following:

time mflux-generate \
--prompt "Luxury food photograph" \
--model schnell \
--steps 2 \
--seed 2 \
--height 1024 \
--width 1024

To find out the spec of your machine (including the number of CPU cores, GPU cores, and memory), run the following command:

system_profiler SPHardwareDataType SPDisplaysDataType

| Device | M-series | User | Reported Time | Notes |
| --- | --- | --- | --- | --- |
| Mac Studio | 2023 M2 Ultra | @awni | <15s | |
| Macbook Pro | 2024 M4 Max (128GB) | @ivanfioravanti | ~19s | |
| Macbook Pro | 2023 M3 Max | @karpathy | ~20s | |
| - | 2023 M2 Max (96GB) | @explorigin | ~25s | |
| Mac Mini | 2024 M4 Pro (64GB) | @Stoobs | ~34s | |
| Mac Mini | 2023 M2 Pro (32GB) | @leekichko | ~54s | |
| - | 2022 M1 MAX (64GB) | @BosseParra | ~55s | |
| Macbook Pro | 2023 M2 Max (32GB) | @filipstrand | ~70s | |
| - | 2023 M3 Pro (36GB) | @kush-gupt | ~80s | |
| Mac Mini | 2024 M4 (16GB) | @wnma3mz | ~97s | 512 x 512, 8-bit quantization |
| Macbook Pro | 2021 M1 Pro (32GB) | @filipstrand | ~160s | |
| - | 2021 M1 Pro (16GB) | @qw-in | ~175s | Might freeze your mac |
| Macbook Air | 2020 M1 (8GB) | @mbvillaverde | ~335s | With resolution 512 x 512 |

Note that these numbers include starting the application from scratch, which means doing model I/O, setting/quantizing weights, etc. To see the duration of the denoising loop alone (excluding text embedding, and assuming the model is already loaded), inspect the image metadata using exiftool image.png.

These benchmarks are not very scientific and are only intended to give ballpark numbers. They were performed at different times with different MFLUX and MLX versions, etc. Additional hardware information, such as the number of GPU cores or the exact Mac device, is not always known.

↔️ Equivalent to Diffusers implementation

There is only a single source of randomness when generating an image: the initial latent array. In this implementation, the initial latent is fully deterministically controlled by the input seed parameter. However, if we were to import a fixed instance of this latent array saved from the Diffusers implementation, then MFLUX would produce an image identical to the Diffusers implementation (assuming a fixed prompt and the default parameter settings in the Diffusers setup).

The images below illustrate this equivalence. In all cases, the Schnell model was run for 2 time steps. The Diffusers implementation ran in CPU mode. The precision for MFLUX can be set in the Config class. There is typically a noticeable but very small difference in the final image when switching between 16-bit and 32-bit precision.


Luxury food photograph

image


detailed cinematic dof render of an old dusty detailed CRT monitor on a wooden desk in a dim room with items around, messy dirty room. On the screen are the letters "FLUX" glowing softly. High detail hard surface render

image


photorealistic, lotr, A tiny red dragon curled up asleep inside a nest, (Soft Focus) , (f_stop 2.8) , (focal_length 50mm) macro lens f/2. 8, medieval wizard table, (pastel) colors, (cozy) morning light filtering through a nearby window, (whimsical) steam shapes, captured with a (Canon EOS R5) , highlighting (serene) comfort, medieval, dnd, rpg, 3d, 16K, 8K

image


A weathered fisherman in his early 60s stands on the deck of his boat, gazing out at a stormy sea. He has a thick, salt-and-pepper beard, deep-set blue eyes, and skin tanned and creased from years of sun exposure. He's wearing a yellow raincoat and hat, with water droplets clinging to the fabric. Behind him, dark clouds loom ominously, and waves crash against the side of the boat. The overall atmosphere is one of tension and respect for the power of nature.

image


Luxury food photograph of an italian Linguine pasta alle vongole dish with lots of clams. It has perfect lighting and a cozy background with big bokeh and shallow depth of field. The mood is a sunset balcony in tuscany.  The photo is taken from the side of the plate. The pasta is shiny with sprinkled parmesan cheese and basil leaves on top. The scene is complemented by a warm, inviting light that highlights the textures and colors of the ingredients, giving it an appetizing and elegant look.

image


🗜️ Quantization

MFLUX supports running FLUX in 4-bit or 8-bit quantized mode. Running a quantized version can greatly speed up the generation process and reduce the memory consumption by several gigabytes. Quantized models also take up less disk space.

mflux-generate \
    --model schnell \
    --steps 2 \
    --seed 2 \
    --quantize 8 \
    --height 1920 \
    --width 1024 \
    --prompt "Tranquil pond in a bamboo forest at dawn, the sun is barely starting to peak over the horizon, panda practices Tai Chi near the edge of the pond, atmospheric perspective through the mist of morning dew, sunbeams, its movements are graceful and fluid — creating a sense of harmony and balance, the pond’s calm waters reflecting the scene, inviting a sense of meditation and connection with nature, style of Howard Terpning and Jessica Rossier"

image

In this example, weights are quantized at runtime - this is convenient if you don't want to save a quantized copy of the weights to disk, but still want to benefit from the potential speedup and RAM reduction quantization might bring.

By setting the --quantize or -q flag to 4, 8, or omitting it entirely, we get all three images above. As can be seen, there is very little difference between the images (especially between the 8-bit and the non-quantized results). Image generation times in this example are based on a 2021 M1 Pro (32GB) machine. Even though the images are almost identical, there is a ~2x speedup by running the 8-bit quantized version on this particular machine. Unlike the non-quantized version, the 8-bit version drastically reduces swap memory usage and keeps GPU utilization close to 100% during the whole generation. Results here can vary across different machines.

📊 Size comparisons for quantized models

The model sizes for both schnell and dev at various quantization levels are as follows:

| 4 bit | 8 bit | Original (16 bit) |
| --- | --- | --- |
| 9.85GB | 18.16GB | 33.73GB |

The reason the weight sizes are not cut exactly in half is that a small number of weights are not quantized and are kept at full precision.
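To check the on-disk size of a saved quantized model yourself (using the save path from the example in the next section):

du -sh /Users/filipstrand/Desktop/schnell_8bit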

💾 Saving a quantized version to disk

To save a local copy of the quantized weights, run the mflux-save command like so:

mflux-save \
    --path "/Users/filipstrand/Desktop/schnell_8bit" \
    --model schnell \
    --quantize 8

Note that when saving a quantized version, you will need the original huggingface weights.

It is also possible to specify LoRA adapters when saving the model, e.g.

mflux-save \
    --path "/Users/filipstrand/Desktop/schnell_8bit" \
    --model schnell \
    --quantize 8 \
    --lora-paths "/path/to/lora.safetensors" \
    --lora-scales 0.7

When generating images with a model like this, no LoRA adapter needs to be specified, since it is already baked into the saved quantized weights.

💽 Loading and running a quantized version from disk

To generate a new image from the quantized model, simply provide a --path to where it was saved:

mflux-generate \
    --path "/Users/filipstrand/Desktop/schnell_8bit" \
    --model schnell \
    --steps 2 \
    --seed 2 \
    --height 1920 \
    --width 1024 \
    --prompt "Tranquil pond in a bamboo forest at dawn, the sun is barely starting to peak over the horizon, panda practices Tai Chi near the edge of the pond, atmospheric perspective through the mist of morning dew, sunbeams, its movements are graceful and fluid — creating a sense of harmony and balance, the pond’s calm waters reflecting the scene, inviting a sense of meditation and connection with nature, style of Howard Terpning and Jessica Rossier"

Note: When loading a quantized model from disk, there is no need to pass in the -q flag, since we can infer this from the weight metadata.

Also note: Once we have a local model (quantized or not) specified via the --path argument, the Huggingface cache models are not required to launch the model. In other words, you can reclaim 34GB of disk space (per model) by deleting the full 16-bit model from the Huggingface cache if you choose.
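For example, one way to remove the full-precision Schnell weights (assuming the default cache location shown earlier; adjust the path if you have changed HF_HOME):

rm -rf ~/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-schnell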

If you don't want to download the full models and quantize them yourself, the 4-bit weights are available here for a direct download:

💽 Running a non-quantized model directly from disk

MFLUX also supports running a non-quantized model directly from a custom location. In the example below, the model is placed in /Users/filipstrand/Desktop/schnell:

mflux-generate \
    --path "/Users/filipstrand/Desktop/schnell" \
    --model schnell \
    --steps 2 \
    --seed 2 \
    --prompt "Luxury food photograph"

Note that the --model flag must be set when loading a model from disk.

Also note that, unlike the typical alias-based way of initializing the model (which internally handles downloading the required resources), loading a model directly from disk requires the downloaded files to be structured as follows:

.
├── text_encoder
│   └── model.safetensors
├── text_encoder_2
│   ├── model-00001-of-00002.safetensors
│   └── model-00002-of-00002.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── special_tokens_map.json
│   ├── spiece.model
│   ├── tokenizer.json
│   └── tokenizer_config.json
├── transformer
│   ├── diffusion_pytorch_model-00001-of-00003.safetensors
│   ├── diffusion_pytorch_model-00002-of-00003.safetensors
│   └── diffusion_pytorch_model-00003-of-00003.safetensors
└── vae
    └── diffusion_pytorch_model.safetensors

This mirrors how the resources are placed in the HuggingFace Repo for FLUX.1. Huggingface weights, unlike quantized ones exported directly from this project, have to be processed a bit differently, which is why we require this structure above.
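One way to obtain such a local copy is to download the repository with the Hugging Face CLI, for example (a sketch; note that this downloads all files in the repo, which may include more than what is strictly needed):

huggingface-cli download black-forest-labs/FLUX.1-schnell --local-dir /Users/filipstrand/Desktop/schnell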


🎨 Image-to-Image

One way to condition the image generation is to start from an existing image and let MFLUX produce new variations. Use the --init-image-path flag to specify the reference image, and --init-image-strength to control how much the reference image should guide the generation. For example, given the reference image below, the following command produced the first image using the Sketching LoRA:

mflux-generate \
--prompt "sketching of an Eiffel architecture, masterpiece, best quality. The site is lit by lighting professionals, creating a subtle illumination effect. Ink on paper with very fine touches with colored markers, (shadings:1.1), loose lines, Schematic, Conceptual, Abstract, Gestural. Quick sketches to explore ideas and concepts." \
--init-image-path "reference.png" \
--init-image-strength 0.3 \
--lora-paths Architectural_Sketching.safetensors \
--lora-scales 1.0 \
--model dev \
--steps 20 \
--seed 43 \
--guidance 4.0 \
--quantize 8 \
--height 1024 \
--width 1024

Like with Controlnet, this technique combines well with LoRA adapters:

image

In the examples above, the following LoRAs are used: Sketching, Animation Shot, and flux-film-camera.


🔌 LoRA

MFLUX supports loading trained LoRA adapters (see the Dreambooth fine-tuning section for training your own).

The following example uses the The_Hound LoRA from @TheLastBen:

mflux-generate --prompt "sandor clegane" --model dev --steps 20 --seed 43 -q 8 --lora-paths "sandor_clegane_single_layer.safetensors"

image

The following example uses the Flux_1_Dev_LoRA_Paper-Cutout-Style LoRA from @Norod78:

mflux-generate --prompt "pikachu, Paper Cutout Style" --model schnell --steps 4 --seed 43 -q 8 --lora-paths "Flux_1_Dev_LoRA_Paper-Cutout-Style.safetensors"

image

Note that LoRA weights are typically trained with a trigger word or phrase. For example, in the latter case, the prompt should include the phrase "Paper Cutout Style".

Also note that the same LoRA weights can work well with both the schnell and dev models. Refer to the original LoRA repository to see what model it was trained for.
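For example, to try the first LoRA above with the schnell model instead of dev (settings are illustrative):

mflux-generate --prompt "sandor clegane" --model schnell --steps 4 --seed 43 -q 8 --lora-paths "sandor_clegane_single_layer.safetensors"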

Multi-LoRA

Multiple LoRAs can be passed in to combine the effects of the individual adapters. The following example combines both of the above LoRAs:

mflux-generate \
   --prompt "sandor clegane in a forest, Paper Cutout Style" \
   --model dev \
   --steps 20 \
   --seed 43 \
   --lora-paths sandor_clegane_single_layer.safetensors Flux_1_Dev_LoRA_Paper-Cutout-Style.safetensors \
   --lora-scales 1.0 1.0 \
   -q 8

image

Just to see the difference, this image displays the four cases of having both adapters fully active, partially active, and no LoRA at all. The example above also shows the usage of the --lora-scales flag.

Supported LoRA formats (updated)

Since different fine-tuning services can use different implementations of FLUX, the corresponding LoRA weights trained on these services can differ from one another. The aim of MFLUX is to support the most common ones. The following table shows the currently supported formats:

| Supported | Name | Example | Notes |
| --- | --- | --- | --- |
| ✅ | BFL | civitai - Impressionism | Many things on civitai seem to work |
| ✅ | Diffusers | Flux_1_Dev_LoRA_Paper-Cutout-Style | |
| ✅ | XLabs-AI | flux-RealismLora | |

To report additional formats, examples, or any other suggestions related to LoRA format support, please see issue #47.


🎛 Dreambooth fine-tuning

As of release v0.5.0, MFLUX has support for fine-tuning your own LoRA adapters using the Dreambooth technique.

image

This example shows the MFLUX training progression of the included training example which is based on the DreamBooth Dataset, also used in the mlx-examples repo.

Training configuration

To describe a training run, you need to provide a training configuration file which specifies details such as what training data to use and various parameters. The easiest way to try it out is to start from the provided example configuration and simply use your own dataset and prompts by modifying the examples section of the JSON file.

Training example

A complete example (training configuration + dataset) is provided in this repository. To start a training run, go to the project folder (cd mflux) and simply run:

mflux-train --train-config src/mflux/dreambooth/_example/train.json

By default, this will train an adapter on 512x512 images with a batch size of 1, and it can take several hours to complete depending on your machine. If this task is too computationally demanding, see the section on memory issues for tips on how to speed things up and what tradeoffs exist.

During training, MFLUX outputs training checkpoints with artifacts (weights, states) according to what is specified in the configuration file. As specified in train.json, these files will be placed in a folder on the Desktop called ~/Desktop/train, but this can of course be changed to any other path by adjusting the configuration. Each training artifact is saved as a self-contained zip file, which can later be pointed to in order to resume an existing training run. To find the LoRA weights, simply unzip the archive, look for the adapter safetensors file, and use it as you would a regular downloaded LoRA adapter.
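For example, a checkpoint produced in ~/Desktop/train can be unpacked and its adapter used like any other LoRA (the checkpoint and adapter file names below are placeholders; they depend on your configuration and when the checkpoint was written):

unzip ~/Desktop/train/0001000_checkpoint.zip -d 0001000_checkpoint
# locate the adapter .safetensors file in the extracted folder, then:
mflux-generate --model dev --steps 20 -q 8 \
   --prompt "your training prompt here" \
   --lora-paths 0001000_checkpoint/<adapter>.safetensors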

Resuming a training run

The training process continues until each training example has been used num_epochs times. For various reasons, however, the user might choose to interrupt the process. To resume training from a given checkpoint, say 0001000_checkpoint.zip, simply run:

mflux-train --train-checkpoint 0001000_checkpoint.zip

There are two nice properties of the training procedure:

  • Fully deterministic (given a specified seed in the training configuration)
  • The complete training state (including optimizer state) is saved at each checkpoint.

Because of these, MFLUX can resume a training run from a previous checkpoint and produce results exactly identical to a training run that was never interrupted in the first place.
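For example (both commands are the ones shown above; the checkpoint number depends on when the run was interrupted):

# start a run and interrupt it at some point (e.g. with Ctrl-C)
mflux-train --train-config src/mflux/dreambooth/_example/train.json

# later, resume from the saved checkpoint and continue until completion
mflux-train --train-checkpoint 0001000_checkpoint.zip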

⚠️ Note: Everything but the dataset itself is contained within this zip file, as the dataset can be quite large. The zip file will contain configuration files which point to the original dataset, so make sure that it is in the same place when resuming training.

⚠️ Note: One current limitation is that a training run can only be resumed if it has not yet been completed. In other words, only checkpoints that represent an interrupted training run can be resumed and run until completion.

Configuration details

Currently, MFLUX supports fine-tuning only for the transformer part of the model. In the training configuration, under lora_layers, you can specify which layers you want to train. The available ones are:

  • transformer_blocks:
    • attn.to_q
    • attn.to_k
    • attn.to_v
    • attn.add_q_proj
    • attn.add_k_proj
    • attn.add_v_proj
    • attn.to_out
    • attn.to_add_out
    • ff.linear1
    • ff.linear2
    • ff_context.linear1
    • ff_context.linear2
  • single_transformer_blocks:
    • proj_out
    • proj_mlp
    • attn.to_q
    • attn.to_k
    • attn.to_v

The block_range under the respective layer category specifies which blocks to train. The maximum ranges available for the different layer categories are:

  • transformer_blocks:
    • start: 0
    • end: 19
  • single_transformer_blocks:
    • start: 0
    • end: 38

⚠️ Note: As the joint transformer blocks (transformer_blocks) are placed earlier in the sequence of computations, they require more resources to train. In other words, training only later layers, such as the single_transformer_blocks, should be faster. However, training too few layers, or only later ones, might be faster but lead to an unsuccessful training run.

Under the examples section, there is an argument called "path" which specifies where the images are located. This path is relative to the config file itself.

Memory issues

Depending on the configuration of the training setup, fine-tuning can be quite memory intensive. In the worst case, if your Mac runs out of memory, it might freeze completely and crash!

To avoid this, consider some of the following strategies to reduce memory requirements by adjusting the parameters in the training configuration:

  • Use a quantized base model by setting "quantize": 4 or "quantize": 8
  • For the layer_types, consider skipping some of the trainable layers (e.g. by not including proj_out etc.)
  • Use a lower rank value for the LoRA matrices.
  • Don't train all the 38 layers from single_transformer_blocks or all of the 19 layers from transformer_blocks
  • Use a smaller batch size, for example "batch_size": 1
  • Make sure your Mac is not busy with other background tasks that hold memory.

Applying some of these strategies, as train.json does by default, will allow a 32GB M1 Pro to perform a successful fine-tuning run. Note, however, that reducing the number of trainable parameters might lead to worse performance.
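For illustration, a hedged sketch of how such memory-reducing choices might look in a training configuration; only the keys discussed above (quantize, batch_size, lora_layers, block_range, layer_types and the LoRA rank) are taken from this README, so the exact structure and any other field names should be checked against the bundled train.json:

{
  "quantize": 8,
  "batch_size": 1,
  "lora_layers": {
    "single_transformer_blocks": {
      "block_range": { "start": 20, "end": 38 },
      "layer_types": ["attn.to_q", "attn.to_k", "attn.to_v"],
      "rank": 4
    }
  }
}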

Additional techniques such as gradient checkpointing and other strategies might be implemented in the future.

Misc

This feature is currently v1 and should be considered a bit experimental. Interfaces (configuration file setup etc.) might change. The aim is to gradually expand the scope of this feature with alternative techniques, data augmentation, etc.

  • As with loading external LoRA adapters, the MFLUX training currently only supports training the transformer part of the network.
  • Sometimes, a model trained with the dev model might actually work better when applied to the schnell weights.
  • Currently, all training images are assumed to be in the resolution specified in the configuration file.
  • The loss curve can be a bit misleading and hard to read: sometimes it conveys little improvement over time, while actual image samples show the real progress.
  • When plotting the loss during training, we label it "validation loss", but for now it is actually computed only on the first 10 training examples. Future updates should support user-provided separate validation images.
  • Training also works when the base model is quantized!
  • For the curious, a motivation for the loss function can be found here.
  • Two great resources that heavily inspired this feature are:

🕹️ Controlnet

MFLUX has Controlnet support for an even more fine-grained control of the image generation. By providing a reference image via --controlnet-image-path and a strength parameter via --controlnet-strength, you can guide the generation toward the reference image.

mflux-generate-controlnet \
  --prompt "A comic strip with a joker in a purple suit" \
  --model dev \
  --steps 20 \
  --seed 1727047657 \
  --height 1066 \
  --width 692 \
  -q 8 \
  --lora-paths "Dark Comic - s0_8 g4.safetensors" \
  --controlnet-image-path "reference.png" \
  --controlnet-strength 0.5 \
  --controlnet-save-canny

image

This example combines the controlnet reference image with the LoRA Dark Comic Flux.

⚠️ Note: Controlnet requires an additional one-time download of ~3.58GB of weights from Huggingface. This happens automatically the first time you run the generate-controlnet command. At the moment, the Controlnet used is InstantX/FLUX.1-dev-Controlnet-Canny, which was trained for the dev model. It can work well with schnell, but performance is not guaranteed.

⚠️ Note: The output can be highly sensitive to the controlnet strength and is very much dependent on the reference image. Settings that are too high will corrupt the image. A recommended starting point is a value around 0.4; from there, experiment with the strength.

Controlnet can also work well together with LoRA adapters. In the example below the same reference image is used as a controlnet input with different prompts and LoRA adapters active.

image

🚧 Current limitations

  • Images are generated one by one.
  • Negative prompts are not supported.
  • LoRA weights are only supported for the transformer part of the network.
  • Some LoRA adapters do not work.
  • Currently, the supported controlnet is the canny-only version.
  • Dreambooth training currently does not support sending in training parameters as flags.

💡Workflow Tips

  • To hide the model fetching status progress bars, export HF_HUB_DISABLE_PROGRESS_BARS=1
  • Use config files to save complex job parameters in a file instead of passing many --args
  • Set up shell aliases for required args, for example:
    • shortcut for dev model: alias mflux-dev='mflux-generate --model dev'
    • shortcut for schnell model and always save metadata: alias mflux-schnell='mflux-generate --model schnell --metadata'
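With these aliases in place, a generation command shortens to, for example:

mflux-schnell --prompt "Luxury food photograph" --steps 2 --seed 2 -q 8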

✅ TODO

🔬 Cool research / features to support

🌱‍ Related projects

License

This project is licensed under the MIT License.