
feat: add multi-output support #396

Merged
2 commits merged into main from multi_output on Dec 11, 2023
Conversation

@fd0r fd0r (Collaborator) commented Nov 24, 2023

@cla-bot cla-bot bot added the cla-signed label Nov 24, 2023
@fd0r fd0r force-pushed the multi_output branch 2 times, most recently from 5f5a4de to 1a7c302 on November 24, 2023 13:44
Makefile (review thread: outdated, resolved)
@fd0r fd0r force-pushed the multi_output branch 15 times, most recently from 82d7d9f to 6fb8011 on November 29, 2023 18:41
benchmarks/deep_learning.py (review thread: outdated, resolved)
@fd0r fd0r force-pushed the multi_output branch 7 times, most recently from 30aa027 to 9543d4b on December 5, 2023 15:56
@fd0r fd0r marked this pull request as ready for review December 5, 2023 17:49
@fd0r fd0r requested a review from a team as a code owner December 5, 2023 17:49
RomanBredehoft previously approved these changes Dec 7, 2023

@RomanBredehoft RomanBredehoft (Collaborator) left a comment

Looks good! I only have a few open questions remaining, but I would be fine with merging it as is. Again, thanks a lot for this integration!

@@ -1116,9 +1115,11 @@ def quantize_input(self, X: numpy.ndarray) -> numpy.ndarray:
         assert isinstance(q_X, numpy.ndarray)
         return q_X

-    def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray:
+    def dequantize_output(self, *q_y_preds: numpy.ndarray) -> numpy.ndarray:
Collaborator

Just realized we have a signature change here: dequantize_output is a method common to all built-in models, which means it 'inherits' from the BaseEstimator method, which has the following signature / docstring:

    @abstractmethod
    def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray:
        """De-quantize the output.

        This step ensures that the fit method has been called.

        Args:
            q_y_preds (numpy.ndarray): The quantized output values to de-quantize.

        Returns:
            numpy.ndarray: The de-quantized output values.
        """

What should we do here, then? I feel like this signature change is necessary (I guess QNNs could be multi-output, right?), but dequantize_output for linear / tree models should stay as it is right now.

I'm actually surprised mypy did not complain here; I thought it would point out that the method's signature has changed?

So yes, I'm not sure what could be done here. It probably looks fine to just keep it like this, but maybe add a new docstring to better indicate that the signature is a bit different?
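(For illustration only, one possible shape for such an updated docstring on the variadic override; the wording below is an assumption, not text from the PR:)

    def dequantize_output(self, *q_y_preds: numpy.ndarray) -> numpy.ndarray:
        """De-quantize the output.

        This step ensures that the fit method has been called.

        Unlike the single-array BaseEstimator signature, this override takes one
        quantized array per model output in order to support multi-output models;
        single-output models pass exactly one array.

        Args:
            *q_y_preds (numpy.ndarray): The quantized output values to de-quantize,
                one array per model output.

        Returns:
            numpy.ndarray: The de-quantized output values.
        """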

fd0r (Collaborator, Author)

So actually, if I understand correctly, def f(arg) and def f(*arg) aren't incompatible, since one is just a superset of the other.

I am not sure it makes sense to change it for built-in models too, but that's a relevant question indeed.

Collaborator

If def f(*arg) is properly considered a superset of def f(arg) in signatures, then I think it's fine!
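(For reference, a minimal self-contained sketch of why mypy accepts this: the class and method below are illustrative stand-ins, not the library's code. Running mypy on this file reports no override error, which matches what was observed on this PR's CI:)

# signature_check.py: run `mypy signature_check.py`. The variadic override is
# accepted, since *q_y_preds can absorb any positional call the base accepts.
from abc import ABC, abstractmethod

import numpy


class Base(ABC):
    @abstractmethod
    def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray:
        """De-quantize a single quantized output array."""


class MultiOutputModel(Base):
    def dequantize_output(self, *q_y_preds: numpy.ndarray) -> numpy.ndarray:
        """Accept one or several quantized output arrays."""
        # Toy behavior: stack several outputs, pass a single one through unchanged.
        return numpy.stack(q_y_preds, axis=-1) if len(q_y_preds) > 1 else q_y_preds[0]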

@@ -34,7 +34,7 @@ readme = "README.md"
# Investigate if it is better to fix specific versions or use lower and upper bounds
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2665
python = ">=3.8.1,<3.11"
concrete-python = "2.5.0-rc1 "
#concrete-python = "2023.11.5"
Collaborator

I think that's the main issue with nightlies: I believe we don't want them to be public, only RCs should be.

tests/torch/test_compile_torch.py (review thread: resolved)
RomanBredehoft previously approved these changes Dec 7, 2023

@RomanBredehoft RomanBredehoft (Collaborator) left a comment

Thanks, all is good for me!

RomanBredehoft previously approved these changes Dec 7, 2023
for (_, layer) in self.quant_layers_dict.values():
    layer.debug_value_tracker = None
return result, debug_value_tracker

# De-quantize the output predicted values
y_pred = self.dequantize_output(*to_tuple(q_y_pred))
Collaborator

You can keep the debug values here as well; the computation is still done on integers.
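(For context, a runnable toy sketch of the call pattern above. The to_tuple helper and the de-quantization parameters here are simplified stand-ins, not Concrete ML's actual implementation:)

import numpy


def to_tuple(x):
    """Wrap a single array into a 1-tuple; pass tuples/lists through unchanged."""
    return tuple(x) if isinstance(x, (tuple, list)) else (x,)


def dequantize_output(*q_y_preds: numpy.ndarray) -> numpy.ndarray:
    """Toy de-quantization: apply a scale/zero-point, then stack multiple outputs."""
    scale, zero_point = 0.05, 2
    outputs = [scale * (q - zero_point) for q in q_y_preds]
    return numpy.stack(outputs, axis=-1) if len(outputs) > 1 else outputs[0]


# The same call site works for single-output and multi-output results:
single = numpy.array([10, 20])                          # one quantized output
multi = (numpy.array([10, 20]), numpy.array([30, 40]))  # two quantized outputs
print(dequantize_output(*to_tuple(single)).shape)  # (2,)
print(dequantize_output(*to_tuple(multi)).shape)   # (2, 2)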

@RomanBredehoft RomanBredehoft (Collaborator) left a comment

Sorry, but if we add new flaky marks to pytest, it's better to create an issue and add a FIXME alongside these marks!

@fd0r fd0r (Collaborator, Author) commented Dec 8, 2023

Indeed, my bad. I'll create the issue and add the FIXME.
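(A hedged sketch of the convention being asked for; the marker, test name, and issue number are placeholders, not what was actually added in the follow-up:)

import pytest


# FIXME: flaky test, tracked in
# https://github.com/zama-ai/concrete-ml-internal/issues/XXXX (placeholder number)
@pytest.mark.flaky  # assumes a `flaky` marker is registered (or a rerun plugin is used)
def test_multi_output_predictions():
    ...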

@RomanBredehoft RomanBredehoft (Collaborator) left a comment

Thanks a lot for this!


Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    6430      0   100%

51 files skipped due to complete coverage.

@jfrery jfrery (Collaborator) left a comment

Thanks!

@fd0r fd0r merged commit 4f67883 into main Dec 11, 2023
9 checks passed
@fd0r fd0r deleted the multi_output branch December 11, 2023 15:29
6 participants