diff --git a/lib/vast/Conversion/Parser/ToParser.cpp b/lib/vast/Conversion/Parser/ToParser.cpp index e94b20c8bf..80b7509864 100644 --- a/lib/vast/Conversion/Parser/ToParser.cpp +++ b/lib/vast/Conversion/Parser/ToParser.cpp @@ -258,6 +258,20 @@ namespace vast::conv { : base(mctx), models(models) {} + static std::optional< function_model > get_model( + const function_models &models, string_ref name + ) { + if (auto kv = models.find(name); kv != models.end()) { + return kv->second; + } + + return std::nullopt; + } + + std::optional< function_model > get_model(string_ref name) const { + return get_model(models, name); + } + const function_models ⊧ }; @@ -408,34 +422,60 @@ namespace vast::conv { } }; - struct FuncConversion - : parser_conversion_pattern_base< hl::FuncOp > - , tc::op_type_conversion< hl::FuncOp, function_type_converter > + struct ReturnConversion + : parser_conversion_pattern_base< hl::ReturnOp > { - using op_t = hl::FuncOp; + using op_t = hl::ReturnOp; using base = parser_conversion_pattern_base< op_t >; using base::base; using adaptor_t = typename op_t::Adaptor; - static std::optional< function_model > get_model( - const function_models &models, op_t op - ) { - if (auto kv = models.find(op.getSymName()); kv != models.end()) { - return kv->second; - } + logical_result matchAndRewrite( + op_t op, adaptor_t adaptor, conversion_rewriter &rewriter + ) const override { + auto func = op->getParentOfType< hl::FuncOp >(); + auto model = get_model(func.getSymName()); - return std::nullopt; + auto rty = model + ? model->get_return_type(rewriter.getContext()) + : pr::MaybeDataType::get(rewriter.getContext()); + + rewriter.replaceOpWithNewOp< op_t >( + op, convert_value_types(adaptor.getOperands(), rty, rewriter) + ); + + return mlir::success(); } - std::optional< function_model > get_model(op_t op) const { - return get_model(models, op); + static void legalize(parser_conversion_config &cfg) { + cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); + cfg.target.addDynamicallyLegalOp< op_t >([](op_t op) { + for (auto ty : op.getResult().getType()) { + if (!is_parser_type(ty)) { + return false; + } + } + return true; + }); } + }; + + struct FuncConversion + : parser_conversion_pattern_base< hl::FuncOp > + , tc::op_type_conversion< hl::FuncOp, function_type_converter > + { + using op_t = hl::FuncOp; + using base = parser_conversion_pattern_base< op_t >; + using base::base; + + using adaptor_t = typename op_t::Adaptor; + logical_result matchAndRewrite( op_t op, adaptor_t adaptor, conversion_rewriter &rewriter ) const override { - auto tc = function_type_converter(*rewriter.getContext(), get_model(op)); + auto tc = function_type_converter(*rewriter.getContext(), get_model(op.getSymName())); if (auto func_op = mlir::dyn_cast< core::function_op_interface >(op.getOperation())) { return this->replace(func_op, rewriter, tc); } @@ -447,7 +487,7 @@ namespace vast::conv { cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >(); cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) { return function_type_converter( - *op.getContext(), get_model(models, op) + *op.getContext(), get_model(models, op.getSymName()) ).isLegal(op.getFunctionType()); }); } @@ -459,6 +499,7 @@ namespace vast::conv { FuncConversion, ParamConversion, DeclRefConversion, + ReturnConversion, CallConversion >;