From ffeabc3d3f50bcb0aa4c91ae3b933a106da6520b Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 1 Dec 2022 16:13:31 -0800 Subject: [PATCH] Make cast() checked by default and add try_cast() for cases where a cast is not required to be successful. --- ASTree.cpp | 38 ++++++++++++++++++-------------------- pyc_code.cpp | 18 +++++++++--------- pyc_code.h | 8 ++++---- pyc_module.cpp | 4 ++-- pyc_object.h | 4 ++-- pycdas.cpp | 2 +- 6 files changed, 36 insertions(+), 38 deletions(-) diff --git a/ASTree.cpp b/ASTree.cpp index 90c79e1e..4e327ffd 100644 --- a/ASTree.cpp +++ b/ASTree.cpp @@ -841,7 +841,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) if (isUninitAsyncFor) { auto tryBlock = container->nodes().front().cast(); if (!tryBlock->nodes().empty() && tryBlock->blktype() == ASTBlock::BLK_TRY) { - auto store = tryBlock->nodes().front().cast(); + auto store = tryBlock->nodes().front().try_cast(); if (store) { asyncForBlock.cast()->setIndex(store->dest()); } @@ -1798,7 +1798,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) { PycRef printNode; if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT) - printNode = curblock->nodes().back().cast(); + printNode = curblock->nodes().back().try_cast(); if (printNode && printNode->stream() == nullptr && !printNode->eol()) printNode->add(stack.top()); else @@ -1813,7 +1813,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) PycRef printNode; if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT) - printNode = curblock->nodes().back().cast(); + printNode = curblock->nodes().back().try_cast(); if (printNode && printNode->stream() == stream && !printNode->eol()) printNode->add(stack.top()); else @@ -1826,7 +1826,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) { PycRef printNode; if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT) - printNode = curblock->nodes().back().cast(); + printNode = curblock->nodes().back().try_cast(); if (printNode && printNode->stream() == nullptr && !printNode->eol()) printNode->setEol(true); else @@ -1841,7 +1841,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) PycRef printNode; if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT) - printNode = curblock->nodes().back().cast(); + printNode = curblock->nodes().back().try_cast(); if (printNode && printNode->stream() == stream && !printNode->eol()) printNode->setEol(true); else @@ -2152,7 +2152,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) if (curblock->blktype() == ASTBlock::BLK_FOR && !curblock->inited()) { - PycRef tuple = tup.cast(); + PycRef tuple = tup.try_cast(); if (tuple != NULL) tuple->setRequireParens(false); curblock.cast()->setIndex(tup); @@ -2211,7 +2211,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) if (curblock->blktype() == ASTBlock::BLK_FOR && !curblock->inited()) { - PycRef tuple = tup.cast(); + PycRef tuple = tup.try_cast(); if (tuple != NULL) tuple->setRequireParens(false); curblock.cast()->setIndex(tup); @@ -2253,7 +2253,7 @@ PycRef BuildFromCode(PycRef code, PycModule* mod) if (curblock->blktype() == ASTBlock::BLK_FOR && !curblock->inited()) { - PycRef tuple = tup.cast(); + PycRef tuple = tup.try_cast(); if (tuple != NULL) tuple->setRequireParens(false); curblock.cast()->setIndex(tup); @@ -2735,7 +2735,7 @@ void print_src(PycRef node, PycModule* mod) if (param.first.type() == ASTNode::NODE_NAME) { fprintf(pyc_output, "%s = ", param.first.cast()->name()->value()); } else { - PycRef str_name = param.first.cast()->object().require_cast(); + PycRef str_name = param.first.cast()->object().cast(); fprintf(pyc_output, "%s = ", str_name->value()); } print_src(param.second, mod); @@ -2902,20 +2902,18 @@ void print_src(PycRef node, PycModule* mod) break; case ASTNode::NODE_BLOCK: { - if (node.cast()->blktype() == ASTBlock::BLK_ELSE - && node.cast()->size() == 0) + PycRef blk = node.cast(); + if (blk->blktype() == ASTBlock::BLK_ELSE && blk->size() == 0) break; - if (node.cast()->blktype() == ASTBlock::BLK_CONTAINER) { + if (blk->blktype() == ASTBlock::BLK_CONTAINER) { end_line(); - PycRef blk = node.cast(); print_block(blk, mod); end_line(); break; } - fprintf(pyc_output, "%s", node.cast()->type_str()); - PycRef blk = node.cast(); + fprintf(pyc_output, "%s", blk->type_str()); if (blk->blktype() == ASTBlock::BLK_IF || blk->blktype() == ASTBlock::BLK_ELIF || blk->blktype() == ASTBlock::BLK_WHILE) { @@ -2937,7 +2935,7 @@ void print_src(PycRef node, PycModule* mod) } else if (blk->blktype() == ASTBlock::BLK_WITH) { fputs(" ", pyc_output); print_src(blk.cast()->expr(), mod); - PycRef var = blk.cast()->var(); + PycRef var = blk.try_cast()->var(); if (var != NULL) { fputs(" as ", pyc_output); print_src(var, mod); @@ -3248,8 +3246,8 @@ void print_src(PycRef node, PycModule* mod) print_src(dest, mod); } } - } else if (src.type() == ASTNode::NODE_BINARY && - src.cast()->is_inplace() == true) { + } else if (src.type() == ASTNode::NODE_BINARY + && src.cast()->is_inplace()) { print_src(src, mod); } else { print_src(dest, mod); @@ -3325,7 +3323,7 @@ void print_src(PycRef node, PycModule* mod) PycRef ternary = node.cast(); //fputs("(", pyc_output); print_src(ternary->if_expr(), mod); - const auto if_block = ternary->if_block().require_cast(); + const auto if_block = ternary->if_block().cast(); fputs(" if ", pyc_output); if (if_block->negative()) fputs("not ", pyc_output); @@ -3405,7 +3403,7 @@ void decompyle(PycRef code, PycModule* mod) if (store->src().type() == ASTNode::NODE_OBJECT && store->dest().type() == ASTNode::NODE_NAME) { PycRef src = store->src().cast(); - PycRef srcString = src->object().cast(); + PycRef srcString = src->object().try_cast(); PycRef dest = store->dest().cast(); if (srcString != nullptr && srcString->isEqual(code->name().cast()) && dest->name()->isEqual("__qualname__")) { diff --git a/pyc_code.cpp b/pyc_code.cpp index 06ce5f39..38a1013b 100644 --- a/pyc_code.cpp +++ b/pyc_code.cpp @@ -40,27 +40,27 @@ void PycCode::load(PycData* stream, PycModule* mod) else m_flags = 0; - m_code = LoadObject(stream, mod).require_cast(); - m_consts = LoadObject(stream, mod).require_cast(); - m_names = LoadObject(stream, mod).require_cast(); + m_code = LoadObject(stream, mod).cast(); + m_consts = LoadObject(stream, mod).cast(); + m_names = LoadObject(stream, mod).cast(); if (mod->verCompare(1, 3) >= 0) - m_varNames = LoadObject(stream, mod).require_cast(); + m_varNames = LoadObject(stream, mod).cast(); else m_varNames = new PycTuple; if (mod->verCompare(2, 1) >= 0) - m_freeVars = LoadObject(stream, mod).require_cast(); + m_freeVars = LoadObject(stream, mod).cast(); else m_freeVars = new PycTuple; if (mod->verCompare(2, 1) >= 0) - m_cellVars = LoadObject(stream, mod).require_cast(); + m_cellVars = LoadObject(stream, mod).cast(); else m_cellVars = new PycTuple; - m_fileName = LoadObject(stream, mod).require_cast(); - m_name = LoadObject(stream, mod).require_cast(); + m_fileName = LoadObject(stream, mod).cast(); + m_name = LoadObject(stream, mod).cast(); if (mod->verCompare(1, 5) >= 0 && mod->verCompare(2, 3) < 0) m_firstLine = stream->get16(); @@ -68,7 +68,7 @@ void PycCode::load(PycData* stream, PycModule* mod) m_firstLine = stream->get32(); if (mod->verCompare(1, 5) >= 0) - m_lnTable = LoadObject(stream, mod).require_cast(); + m_lnTable = LoadObject(stream, mod).cast(); else m_lnTable = new PycString; } diff --git a/pyc_code.h b/pyc_code.h index bf0fa1d5..588d23b3 100644 --- a/pyc_code.h +++ b/pyc_code.h @@ -58,19 +58,19 @@ class PycCode : public PycObject { PycRef getName(int idx) const { - return m_names->get(idx).require_cast(); + return m_names->get(idx).cast(); } PycRef getVarName(int idx) const { - return m_varNames->get(idx).require_cast(); + return m_varNames->get(idx).cast(); } PycRef getCellVar(int idx) const { return (idx >= m_cellVars->size()) - ? m_freeVars->get(idx - m_cellVars->size()).require_cast() - : m_cellVars->get(idx).require_cast(); + ? m_freeVars->get(idx - m_cellVars->size()).cast() + : m_cellVars->get(idx).cast(); } const globals_t& getGlobals() const { return m_globalsUsed; } diff --git a/pyc_module.cpp b/pyc_module.cpp index aa739e51..4b95666c 100644 --- a/pyc_module.cpp +++ b/pyc_module.cpp @@ -213,7 +213,7 @@ void PycModule::loadFromFile(const char* filename) in.get32(); // Size parameter added in Python 3.3 } - m_code = LoadObject(&in, this).require_cast(); + m_code = LoadObject(&in, this).cast(); } void PycModule::loadFromMarshalledFile(const char* filename, int major, int minor) @@ -230,7 +230,7 @@ void PycModule::loadFromMarshalledFile(const char* filename, int major, int mino m_maj = major; m_min = minor; m_unicode = (major >= 3); - m_code = LoadObject(&in, this).require_cast(); + m_code = LoadObject(&in, this).cast(); } PycRef PycModule::getIntern(int ref) const diff --git a/pyc_object.h b/pyc_object.h index fd7cc1ac..19ca5743 100644 --- a/pyc_object.h +++ b/pyc_object.h @@ -70,10 +70,10 @@ class PycRef { inline int type() const; template - PycRef<_Cast> cast() const { return dynamic_cast<_Cast*>(m_obj); } + PycRef<_Cast> try_cast() const { return dynamic_cast<_Cast*>(m_obj); } template - PycRef<_Cast> require_cast() const + PycRef<_Cast> cast() const { _Cast* result = dynamic_cast<_Cast*>(m_obj); if (!result) diff --git a/pycdas.cpp b/pycdas.cpp index bf540110..e0fb5f98 100644 --- a/pycdas.cpp +++ b/pycdas.cpp @@ -304,7 +304,7 @@ int main(int argc, char* argv[]) fprintf(pyc_output, "%s (Python %d.%d%s)\n", dispname, mod.majorVer(), mod.minorVer(), (mod.majorVer() < 3 && mod.isUnicode()) ? " -U" : ""); try { - output_object(mod.code().cast(), &mod, 0); + output_object(mod.code().try_cast(), &mod, 0); } catch (std::exception& ex) { fprintf(stderr, "Error disassembling %s: %s\n", infile, ex.what()); return 1;