Skip to content

Commit

Permalink
Make cast() checked by default and add try_cast() for cases where a cast
Browse files Browse the repository at this point in the history
is not required to be successful.
  • Loading branch information
zrax committed Dec 2, 2022
1 parent 305494c commit ffeabc3
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 38 deletions.
38 changes: 18 additions & 20 deletions ASTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
if (isUninitAsyncFor) {
auto tryBlock = container->nodes().front().cast<ASTBlock>();
if (!tryBlock->nodes().empty() && tryBlock->blktype() == ASTBlock::BLK_TRY) {
auto store = tryBlock->nodes().front().cast<ASTStore>();
auto store = tryBlock->nodes().front().try_cast<ASTStore>();
if (store) {
asyncForBlock.cast<ASTIterBlock>()->setIndex(store->dest());
}
Expand Down Expand Up @@ -1798,7 +1798,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == nullptr && !printNode->eol())
printNode->add(stack.top());
else
Expand All @@ -1813,7 +1813,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)

PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == stream && !printNode->eol())
printNode->add(stack.top());
else
Expand All @@ -1826,7 +1826,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)
{
PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == nullptr && !printNode->eol())
printNode->setEol(true);
else
Expand All @@ -1841,7 +1841,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)

PycRef<ASTPrint> printNode;
if (curblock->size() > 0 && curblock->nodes().back().type() == ASTNode::NODE_PRINT)
printNode = curblock->nodes().back().cast<ASTPrint>();
printNode = curblock->nodes().back().try_cast<ASTPrint>();
if (printNode && printNode->stream() == stream && !printNode->eol())
printNode->setEol(true);
else
Expand Down Expand Up @@ -2152,7 +2152,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)

if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
Expand Down Expand Up @@ -2211,7 +2211,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)

if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
Expand Down Expand Up @@ -2253,7 +2253,7 @@ PycRef<ASTNode> BuildFromCode(PycRef<PycCode> code, PycModule* mod)

if (curblock->blktype() == ASTBlock::BLK_FOR
&& !curblock->inited()) {
PycRef<ASTTuple> tuple = tup.cast<ASTTuple>();
PycRef<ASTTuple> tuple = tup.try_cast<ASTTuple>();
if (tuple != NULL)
tuple->setRequireParens(false);
curblock.cast<ASTIterBlock>()->setIndex(tup);
Expand Down Expand Up @@ -2735,7 +2735,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
if (param.first.type() == ASTNode::NODE_NAME) {
fprintf(pyc_output, "%s = ", param.first.cast<ASTName>()->name()->value());
} else {
PycRef<PycString> str_name = param.first.cast<ASTObject>()->object().require_cast<PycString>();
PycRef<PycString> str_name = param.first.cast<ASTObject>()->object().cast<PycString>();
fprintf(pyc_output, "%s = ", str_name->value());
}
print_src(param.second, mod);
Expand Down Expand Up @@ -2902,20 +2902,18 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
break;
case ASTNode::NODE_BLOCK:
{
if (node.cast<ASTBlock>()->blktype() == ASTBlock::BLK_ELSE
&& node.cast<ASTBlock>()->size() == 0)
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
if (blk->blktype() == ASTBlock::BLK_ELSE && blk->size() == 0)
break;

if (node.cast<ASTBlock>()->blktype() == ASTBlock::BLK_CONTAINER) {
if (blk->blktype() == ASTBlock::BLK_CONTAINER) {
end_line();
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
print_block(blk, mod);
end_line();
break;
}

fprintf(pyc_output, "%s", node.cast<ASTBlock>()->type_str());
PycRef<ASTBlock> blk = node.cast<ASTBlock>();
fprintf(pyc_output, "%s", blk->type_str());
if (blk->blktype() == ASTBlock::BLK_IF
|| blk->blktype() == ASTBlock::BLK_ELIF
|| blk->blktype() == ASTBlock::BLK_WHILE) {
Expand All @@ -2937,7 +2935,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
} else if (blk->blktype() == ASTBlock::BLK_WITH) {
fputs(" ", pyc_output);
print_src(blk.cast<ASTWithBlock>()->expr(), mod);
PycRef<ASTNode> var = blk.cast<ASTWithBlock>()->var();
PycRef<ASTNode> var = blk.try_cast<ASTWithBlock>()->var();
if (var != NULL) {
fputs(" as ", pyc_output);
print_src(var, mod);
Expand Down Expand Up @@ -3248,8 +3246,8 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
print_src(dest, mod);
}
}
} else if (src.type() == ASTNode::NODE_BINARY &&
src.cast<ASTBinary>()->is_inplace() == true) {
} else if (src.type() == ASTNode::NODE_BINARY
&& src.cast<ASTBinary>()->is_inplace()) {
print_src(src, mod);
} else {
print_src(dest, mod);
Expand Down Expand Up @@ -3325,7 +3323,7 @@ void print_src(PycRef<ASTNode> node, PycModule* mod)
PycRef<ASTTernary> ternary = node.cast<ASTTernary>();
//fputs("(", pyc_output);
print_src(ternary->if_expr(), mod);
const auto if_block = ternary->if_block().require_cast<ASTCondBlock>();
const auto if_block = ternary->if_block().cast<ASTCondBlock>();
fputs(" if ", pyc_output);
if (if_block->negative())
fputs("not ", pyc_output);
Expand Down Expand Up @@ -3405,7 +3403,7 @@ void decompyle(PycRef<PycCode> code, PycModule* mod)
if (store->src().type() == ASTNode::NODE_OBJECT
&& store->dest().type() == ASTNode::NODE_NAME) {
PycRef<ASTObject> src = store->src().cast<ASTObject>();
PycRef<PycString> srcString = src->object().cast<PycString>();
PycRef<PycString> srcString = src->object().try_cast<PycString>();
PycRef<ASTName> dest = store->dest().cast<ASTName>();
if (srcString != nullptr && srcString->isEqual(code->name().cast<PycObject>())
&& dest->name()->isEqual("__qualname__")) {
Expand Down
18 changes: 9 additions & 9 deletions pyc_code.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,35 @@ void PycCode::load(PycData* stream, PycModule* mod)
else
m_flags = 0;

m_code = LoadObject(stream, mod).require_cast<PycString>();
m_consts = LoadObject(stream, mod).require_cast<PycSequence>();
m_names = LoadObject(stream, mod).require_cast<PycSequence>();
m_code = LoadObject(stream, mod).cast<PycString>();
m_consts = LoadObject(stream, mod).cast<PycSequence>();
m_names = LoadObject(stream, mod).cast<PycSequence>();

if (mod->verCompare(1, 3) >= 0)
m_varNames = LoadObject(stream, mod).require_cast<PycSequence>();
m_varNames = LoadObject(stream, mod).cast<PycSequence>();
else
m_varNames = new PycTuple;

if (mod->verCompare(2, 1) >= 0)
m_freeVars = LoadObject(stream, mod).require_cast<PycSequence>();
m_freeVars = LoadObject(stream, mod).cast<PycSequence>();
else
m_freeVars = new PycTuple;

if (mod->verCompare(2, 1) >= 0)
m_cellVars = LoadObject(stream, mod).require_cast<PycSequence>();
m_cellVars = LoadObject(stream, mod).cast<PycSequence>();
else
m_cellVars = new PycTuple;

m_fileName = LoadObject(stream, mod).require_cast<PycString>();
m_name = LoadObject(stream, mod).require_cast<PycString>();
m_fileName = LoadObject(stream, mod).cast<PycString>();
m_name = LoadObject(stream, mod).cast<PycString>();

if (mod->verCompare(1, 5) >= 0 && mod->verCompare(2, 3) < 0)
m_firstLine = stream->get16();
else if (mod->verCompare(2, 3) >= 0)
m_firstLine = stream->get32();

if (mod->verCompare(1, 5) >= 0)
m_lnTable = LoadObject(stream, mod).require_cast<PycString>();
m_lnTable = LoadObject(stream, mod).cast<PycString>();
else
m_lnTable = new PycString;
}
8 changes: 4 additions & 4 deletions pyc_code.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,19 @@ class PycCode : public PycObject {

PycRef<PycString> getName(int idx) const
{
return m_names->get(idx).require_cast<PycString>();
return m_names->get(idx).cast<PycString>();
}

PycRef<PycString> getVarName(int idx) const
{
return m_varNames->get(idx).require_cast<PycString>();
return m_varNames->get(idx).cast<PycString>();
}

PycRef<PycString> getCellVar(int idx) const
{
return (idx >= m_cellVars->size())
? m_freeVars->get(idx - m_cellVars->size()).require_cast<PycString>()
: m_cellVars->get(idx).require_cast<PycString>();
? m_freeVars->get(idx - m_cellVars->size()).cast<PycString>()
: m_cellVars->get(idx).cast<PycString>();
}

const globals_t& getGlobals() const { return m_globalsUsed; }
Expand Down
4 changes: 2 additions & 2 deletions pyc_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ void PycModule::loadFromFile(const char* filename)
in.get32(); // Size parameter added in Python 3.3
}

m_code = LoadObject(&in, this).require_cast<PycCode>();
m_code = LoadObject(&in, this).cast<PycCode>();
}

void PycModule::loadFromMarshalledFile(const char* filename, int major, int minor)
Expand All @@ -230,7 +230,7 @@ void PycModule::loadFromMarshalledFile(const char* filename, int major, int mino
m_maj = major;
m_min = minor;
m_unicode = (major >= 3);
m_code = LoadObject(&in, this).require_cast<PycCode>();
m_code = LoadObject(&in, this).cast<PycCode>();
}

PycRef<PycString> PycModule::getIntern(int ref) const
Expand Down
4 changes: 2 additions & 2 deletions pyc_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ class PycRef {
inline int type() const;

template <class _Cast>
PycRef<_Cast> cast() const { return dynamic_cast<_Cast*>(m_obj); }
PycRef<_Cast> try_cast() const { return dynamic_cast<_Cast*>(m_obj); }

template <class _Cast>
PycRef<_Cast> require_cast() const
PycRef<_Cast> cast() const
{
_Cast* result = dynamic_cast<_Cast*>(m_obj);
if (!result)
Expand Down
2 changes: 1 addition & 1 deletion pycdas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ int main(int argc, char* argv[])
fprintf(pyc_output, "%s (Python %d.%d%s)\n", dispname, mod.majorVer(), mod.minorVer(),
(mod.majorVer() < 3 && mod.isUnicode()) ? " -U" : "");
try {
output_object(mod.code().cast<PycObject>(), &mod, 0);
output_object(mod.code().try_cast<PycObject>(), &mod, 0);
} catch (std::exception& ex) {
fprintf(stderr, "Error disassembling %s: %s\n", infile, ex.what());
return 1;
Expand Down

0 comments on commit ffeabc3

Please sign in to comment.