diff --git a/migrations/versions/2024_05_28_1846-415a8f855e14_multi_colors.py b/migrations/versions/2024_05_28_1846-415a8f855e14_multi_colors.py new file mode 100644 index 000000000..30c1dbc7a --- /dev/null +++ b/migrations/versions/2024_05_28_1846-415a8f855e14_multi_colors.py @@ -0,0 +1,31 @@ +"""multi_colors. + +Revision ID: 415a8f855e14 +Revises: 395d560284b3 +Create Date: 2024-05-28 18:46:12.935449 +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "415a8f855e14" +down_revision = "395d560284b3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Perform the upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("filament", sa.Column("multi_color_hexes", sa.String(length=128), nullable=True)) + op.add_column("filament", sa.Column("multi_color_direction", sa.String(length=16), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Perform the downgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("filament", "multi_color_direction") + op.drop_column("filament", "multi_color_hexes") + # ### end Alembic commands ### diff --git a/spoolman/api/v1/filament.py b/spoolman/api/v1/filament.py index 555ae8453..43419ac0b 100644 --- a/spoolman/api/v1/filament.py +++ b/spoolman/api/v1/filament.py @@ -7,10 +7,10 @@ from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy.ext.asyncio import AsyncSession -from spoolman.api.v1.models import Filament, FilamentEvent, Message +from spoolman.api.v1.models import Filament, FilamentEvent, Message, MultiColorDirection from spoolman.database import filament from spoolman.database.database import get_db_session from spoolman.database.utils import SortOrder @@ -86,9 +86,26 @@ class FilamentParameters(BaseModel): ) color_hex: Optional[str] = Field( None, - description="Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end.", + description=( + "Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end. " + "If it's a multi-color filament, the multi_color_hexes field is used instead." + ), examples=["FF0000"], ) + multi_color_hexes: Optional[str] = Field( + None, + description=( + "Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end. " + "Specifying multiple colors separated by commas. " + "Also set the multi_color_direction field if you specify multiple colors." + ), + examples=["FF0000,00FF00,0000FF"], + ) + multi_color_direction: Optional[MultiColorDirection] = Field( + None, + description=("Type of multi-color filament. Only set if the color_hex field contains multiple colors. "), + examples=["coaxial", "longitudinal"], + ) external_id: Optional[str] = Field( None, max_length=256, @@ -104,24 +121,58 @@ class FilamentParameters(BaseModel): @field_validator("color_hex") @classmethod - @classmethod def color_hex_validator(cls, v: Optional[str]) -> Optional[str]: # noqa: ANN102 """Validate the color_hex field.""" if not v: return None - if v.startswith("#"): - v = v[1:] - v = v.upper() - for c in v: + clr = v.upper() + if clr.startswith("#"): + clr = clr[1:] + + for c in clr: if c not in "0123456789ABCDEF": raise ValueError("Invalid character in color code.") - if len(v) not in (6, 8): + if len(clr) not in (6, 8): raise ValueError("Color code must be 6 or 8 characters long.") return v + @field_validator("multi_color_hexes") + @classmethod + def multi_color_hexes_validator(cls, v: Optional[str]) -> Optional[str]: # noqa: ANN102 + """Validate the multi_color_hexes field.""" + if not v: + return None + for clr_raw in v.split(","): + clr = clr_raw.upper() + if clr.startswith("#"): + clr = clr[1:] + + for c in clr: + if c not in "0123456789ABCDEF": + raise ValueError("Invalid character in color code.") + + if len(clr) not in (6, 8): + raise ValueError("Color code must be 6 or 8 characters long.") + + return v + + @model_validator(mode="after") # type: ignore[] + def validate(self) -> "FilamentParameters": + """Validate the model.""" + if self.color_hex and self.multi_color_hexes: + raise ValueError("Cannot specify both color_hex and multi_color_hexes.") + if self.multi_color_hexes and len(self.multi_color_hexes.split(",")) < 2: # noqa: PLR2004 + raise ValueError("Must specify at least two colors in multi_color_hexes.") + if self.multi_color_hexes and not self.multi_color_direction: + raise ValueError("Multi-color filament must have multi_color_direction set.") + if not self.multi_color_hexes and self.multi_color_direction: + raise ValueError("Single-color filament must not have multi_color_direction set.") + + return self + class FilamentUpdateParameters(FilamentParameters): density: Optional[float] = Field(None, gt=0, description="The density of this filament in g/cm3.", examples=[1.24]) @@ -394,6 +445,8 @@ async def create( # noqa: ANN201 settings_extruder_temp=body.settings_extruder_temp, settings_bed_temp=body.settings_bed_temp, color_hex=body.color_hex, + multi_color_hexes=body.multi_color_hexes, + multi_color_direction=body.multi_color_direction, external_id=body.external_id, extra=body.extra, ) diff --git a/spoolman/api/v1/models.py b/spoolman/api/v1/models.py index d1fa5baf1..8e56af032 100644 --- a/spoolman/api/v1/models.py +++ b/spoolman/api/v1/models.py @@ -93,6 +93,13 @@ def from_db(item: models.Vendor) -> "Vendor": ) +class MultiColorDirection(Enum): + """Enum for multi-color direction.""" + + COAXIAL = "coaxial" + LONGITUDINAL = "longitudinal" + + class Filament(BaseModel): id: int = Field(description="Unique internal ID of this filament type.") registered: SpoolmanDateTime = Field(description="When the filament was registered in the database. UTC Timezone.") @@ -155,9 +162,27 @@ class Filament(BaseModel): None, min_length=6, max_length=8, - description="Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end.", + description=( + "Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end. " + "If it's a multi-color filament, the multi_color_hexes field is used instead." + ), examples=["FF0000"], ) + multi_color_hexes: Optional[str] = Field( + None, + min_length=6, + description=( + "Hexadecimal color code of the filament, e.g. FF0000 for red. Supports alpha channel at the end. " + "Specifying multiple colors separated by commas. " + "Also set the multi_color_direction field if you specify multiple colors." + ), + examples=["FF0000,00FF00,0000FF"], + ) + multi_color_direction: Optional[MultiColorDirection] = Field( + None, + description=("Type of multi-color filament. Only set if the multi_color_hexes field is set."), + examples=["coaxial", "longitudinal"], + ) external_id: Optional[str] = Field( None, max_length=256, @@ -192,6 +217,10 @@ def from_db(item: models.Filament) -> "Filament": settings_extruder_temp=item.settings_extruder_temp, settings_bed_temp=item.settings_bed_temp, color_hex=item.color_hex, + multi_color_hexes=item.multi_color_hexes, + multi_color_direction=( + MultiColorDirection(item.multi_color_direction) if item.multi_color_direction is not None else None + ), external_id=item.external_id, extra={field.key: field.value for field in item.extra}, ) diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index 21904e8d6..e4079624f 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -11,7 +11,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import contains_eager, joinedload -from spoolman.api.v1.models import EventType, Filament, FilamentEvent +from spoolman.api.v1.models import EventType, Filament, FilamentEvent, MultiColorDirection from spoolman.database import models, vendor from spoolman.database.utils import ( SortOrder, @@ -42,6 +42,8 @@ async def create( settings_extruder_temp: Optional[int] = None, settings_bed_temp: Optional[int] = None, color_hex: Optional[str] = None, + multi_color_hexes: Optional[str] = None, + multi_color_direction: Optional[MultiColorDirection] = None, external_id: Optional[str] = None, extra: Optional[dict[str, str]] = None, ) -> models.Filament: @@ -68,6 +70,8 @@ async def create( settings_extruder_temp=settings_extruder_temp, settings_bed_temp=settings_bed_temp, color_hex=color_hex, + multi_color_hexes=multi_color_hexes, + multi_color_direction=multi_color_direction.value if multi_color_direction is not None else None, external_id=external_id, extra=[models.FilamentField(key=k, value=v) for k, v in (extra or {}).items()], ) @@ -231,11 +235,18 @@ async def find_by_color( found_filaments: list[models.Filament] = [] for filament in filaments: - if filament.color_hex is None: + if filament.color_hex is not None: + colors = [filament.color_hex] + elif filament.multi_color_hexes is not None: + colors = filament.multi_color_hexes.split(",") + else: continue - color_lab = rgb_to_lab(hex_to_rgb(filament.color_hex)) - if delta_e(color_query_lab, color_lab) <= similarity_threshold: - found_filaments.append(filament) + + for color in colors: + color_lab = rgb_to_lab(hex_to_rgb(color)) + if delta_e(color_query_lab, color_lab) <= similarity_threshold: + found_filaments.append(filament) + break return found_filaments diff --git a/spoolman/database/models.py b/spoolman/database/models.py index c0dfe188a..992296983 100644 --- a/spoolman/database/models.py +++ b/spoolman/database/models.py @@ -49,6 +49,8 @@ class Filament(Base): settings_extruder_temp: Mapped[Optional[int]] = mapped_column(comment="Overridden extruder temperature.") settings_bed_temp: Mapped[Optional[int]] = mapped_column(comment="Overridden bed temperature.") color_hex: Mapped[Optional[str]] = mapped_column(String(8)) + multi_color_hexes: Mapped[Optional[str]] = mapped_column(String(128)) + multi_color_direction: Mapped[Optional[str]] = mapped_column(String(16)) external_id: Mapped[Optional[str]] = mapped_column(String(256)) extra: Mapped[list["FilamentField"]] = relationship( back_populates="filament", diff --git a/tests_integration/tests/filament/test_add.py b/tests_integration/tests/filament/test_add.py index 57cb1f7b9..b7efe31ac 100644 --- a/tests_integration/tests/filament/test_add.py +++ b/tests_integration/tests/filament/test_add.py @@ -128,3 +128,106 @@ def test_add_filament_color_hex_alpha(): # Clean up httpx.delete(f"{URL}/api/v1/filament/{filament['id']}").raise_for_status() + + +def test_add_filament_multi_color(): + """Test adding a filament with multi color hexes.""" + multi_color_hexes = "FF0000,00FF00" + multi_color_direction = "coaxial" + + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "multi_color_hexes": multi_color_hexes, + "multi_color_direction": multi_color_direction, + }, + ) + result.raise_for_status() + + # Verify + filament = result.json() + assert filament["multi_color_hexes"] == multi_color_hexes + assert filament["multi_color_direction"] == multi_color_direction + + # Clean up + httpx.delete(f"{URL}/api/v1/filament/{filament['id']}").raise_for_status() + + +def test_add_filament_multi_color_errors(): + """Test adding a filament with multi color hexes with errors.""" + # Bad hex color list + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "multi_color_hexes": "FF0000,", # Bad 2nd color + "multi_color_direction": "coaxial", + }, + ) + + # Verify + assert result.status_code == 422 + + # Missing multi_color_hexes but direction is specified + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "multi_color_direction": "coaxial", + }, + ) + + # Verify + assert result.status_code == 422 + + # multi_color_hexes is specified but direction is not + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "multi_color_hexes": "FF0000,00FF00", + }, + ) + + # Verify + assert result.status_code == 422 + + # multi_color_hexes only has a single color + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "multi_color_hexes": "FF0000", + "multi_color_direction": "coaxial", + }, + ) + + # Verify + assert result.status_code == 422 + + # Both multi_color_hexes and color_hex is specified + # Execute + result = httpx.post( + f"{URL}/api/v1/filament", + json={ + "density": 1.25, + "diameter": 1.75, + "color_hex": "FF0000", + "multi_color_hexes": "FF0000,00FF00", + "multi_color_direction": "coaxial", + }, + ) + + # Verify + assert result.status_code == 422