Skip to content

Commit

Permalink
Support cython 3.0 ast generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabioz committed Jul 30, 2023
1 parent 052e46a commit eb778b0
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ private ISimpleNode createNode(JsonValue jsonValue) {
node = createString(asObject);
break;

case "Annotation":
node = createAnnotation(asObject);
break;

case "Tuple":
node = createTuple(asObject);
break;
Expand Down Expand Up @@ -1150,6 +1154,11 @@ private SimpleNode createString(JsonObject asObject) {
return str;
}

private SimpleNode createAnnotation(JsonObject asObject) {
JsonValue value = asObject.get("expr");
return (SimpleNode) createNode(value);
}

private SimpleNode createFString(JsonObject asObject) throws Exception {
boolean raw = false;
boolean unicode = true;
Expand Down Expand Up @@ -1829,7 +1838,7 @@ private ClassDef createCEnumDef(JsonObject asObject) throws Exception {
}
astFactory.setBody(classDef, nodes.toArray());

classDef.decs = createDecorators(asObject);
classDef.decs = createDecorators(asObject, classDef);
return classDef;
}
return null;
Expand All @@ -1847,7 +1856,7 @@ private ClassDef createClassDef(JsonObject asObject) throws Exception {

astFactory.setBody(classDef, extractStmts(asObject, "body").toArray());

classDef.decs = createDecorators(asObject);
classDef.decs = createDecorators(asObject, classDef);
return classDef;
}
return null;
Expand Down Expand Up @@ -1938,7 +1947,7 @@ private ClassDef createCppClass(JsonObject asObject) throws Exception {

astFactory.setBody(classDef, extractStmts(asObject, "attributes").toArray());

classDef.decs = createDecorators(asObject);
classDef.decs = createDecorators(asObject, classDef);
return classDef;
}
return null;
Expand All @@ -1956,7 +1965,7 @@ private ClassDef createCClassDef(JsonObject asObject) throws Exception {

astFactory.setBody(classDef, extractStmts(asObject, "body").toArray());

classDef.decs = createDecorators(asObject);
classDef.decs = createDecorators(asObject, classDef);
return classDef;
}
return null;
Expand All @@ -1979,7 +1988,7 @@ private FunctionDef createFunctionDef(JsonObject asObject) throws Exception {
funcDef.name.beginColumn = funcDef.beginColumn + 4;

funcDef.args = createArgs(asObject);
funcDef.decs = createDecorators(asObject);
funcDef.decs = createDecorators(asObject, funcDef);

JsonValue isAsyncDef = asObject.get("is_async_def");
if (isAsyncDef != null && isAsyncDef.asString().equals("True")) {
Expand Down Expand Up @@ -2011,7 +2020,7 @@ private FunctionDef createCFunctionDef(JsonObject asObject) throws Exception {
if (declarator != null && declarator.isObject()) {
FunctionDef funcDef = createCFuncDeclarator(declarator.asObject());
if (funcDef != null) {
funcDef.decs = createDecorators(asObject);
funcDef.decs = createDecorators(asObject, funcDef);
setLine(funcDef, asObject);
setLine(funcDef.name, asObject);
astFactory.setBody(funcDef, extractStmts(asObject, "body").toArray());
Expand Down Expand Up @@ -2043,7 +2052,7 @@ public List<exprType> extractExprs(JsonObject asObject, String field) {
return lst;
}

private decoratorsType[] createDecorators(JsonObject asObject) throws Exception {
private decoratorsType[] createDecorators(JsonObject asObject, SimpleNode defNode) throws Exception {
List<decoratorsType> decs = new ArrayList<decoratorsType>();
JsonValue jsonValue = asObject.get("decorators");
if (jsonValue != null && jsonValue.isArray()) {
Expand All @@ -2057,6 +2066,25 @@ private decoratorsType[] createDecorators(JsonObject asObject) throws Exception
if (decs.size() == 0) {
return null;
}

// In cython 3 the dec line matches the first decorator line.
int maxDecLine = -1;
for (decoratorsType dec : decs) {
if (dec.beginLine > maxDecLine) {
maxDecLine = dec.beginLine;
}
}

if (defNode.beginLine <= maxDecLine) {
defNode.beginLine = maxDecLine + 1;
if (defNode instanceof FunctionDef) {
FunctionDef functionDef = (FunctionDef) defNode;
functionDef.name.beginLine = defNode.beginLine;
} else if (defNode instanceof ClassDef) {
ClassDef classDef = (ClassDef) defNode;
classDef.name.beginLine = defNode.beginLine;
}
}
return decs.toArray(new decoratorsType[0]);
}

Expand All @@ -2077,9 +2105,11 @@ private decoratorsType createDecorator(JsonValue v) throws Exception {
decorator.starargs = call.starargs;
decorator.kwargs = call.kwargs;
decorator.isCall = true;
decorator.beginLine = call.beginLine;

} else if (func instanceof exprType) {
decorator.func = (exprType) func;
decorator.beginLine = decorator.func.beginLine;

} else {
if (func != null) {
Expand Down Expand Up @@ -2141,7 +2171,8 @@ private argumentsType createArgs(JsonObject funcAsObject) {
boolean isKwOnly = false;
JsonValue kwOnlyValue = asObject.get("kw_only");
if (kwOnlyValue != null && kwOnlyValue.isString()) {
if ("1".equals(kwOnlyValue.asString())) {
final String asString = kwOnlyValue.asString();
if ("1".equals(asString) || "True".equals(asString)) {
isKwOnly = true;
}
}
Expand Down Expand Up @@ -2381,10 +2412,20 @@ private SimpleNode createSingleAssignment(JsonObject asObject) throws Exception

if (left instanceof Name) {
Name leftName = (Name) left;
aliasType[] names = ((Import) right).names;
Import importNode = ((Import) right);
aliasType[] names = importNode.names;
aliasType aliasType = names[0];

if (aliasType.asname != null) {
if (aliasType.asname == null) {
if (leftName != null && leftName.id != null) {
if (((NameTok) aliasType.name).id == null
|| !leftName.id.equals(((NameTok) aliasType.name).id)) {
aliasType.asname = new NameTok(leftName.id, NameTok.ImportName);
aliasType.asname.beginColumn = leftName.beginColumn;
aliasType.asname.beginLine = leftName.beginLine;
}
}
} else {
aliasType.asname = new NameTok(leftName.id, NameTok.ImportName);
aliasType.asname.beginColumn = leftName.beginColumn;
aliasType.asname.beginLine = leftName.beginLine;
Expand Down Expand Up @@ -2538,7 +2579,11 @@ private While createWhile(JsonObject asObject) {
}

private Assert createAssert(JsonObject asObject) {
exprType cond = asExpr(createNode(asObject.get("cond")));
JsonValue condition = asObject.get("cond");
if (condition == null) {
condition = asObject.get("condition");
}
exprType cond = asExpr(createNode(condition));
exprType value = asExpr(createNode(asObject.get("value")));

Assert assertStmt = new Assert(cond, value);
Expand Down Expand Up @@ -2889,7 +2934,11 @@ public String genCythonJson() {
CythonShell serverShell = (CythonShell) AbstractShell.getServerShell(nature,
CompletionProposalFactory.get().getCythonShellId());
String contents = parserInfo.document.get();
return serverShell.convertToJsonAst(StringUtils.replaceNewLines(contents, "\n"));
String ret = serverShell.convertToJsonAst(StringUtils.replaceNewLines(contents, "\n"));
System.out.println("---");
System.out.println(JsonValue.readFrom(ret).toPrettyString());
System.out.println("---");
return ret;
} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import traceback
import os

try:
from Cython.Compiler import Errors
Errors.init_thread() # This is needed in Cython 3.0.0 (otherwise reporting errors will throw exception).
except Exception:
pass

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Note: Cython has some recursive structures in some classes, so, parsing only what we really
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.python.pydev.parser.jython.ast.Name;
import org.python.pydev.parser.jython.ast.NameTok;
import org.python.pydev.parser.jython.ast.Suite;
import org.python.pydev.parser.visitors.NodeUtils;
import org.python.pydev.parser.visitors.comparator.DifferException;
import org.python.pydev.parser.visitors.comparator.SimpleNodeComparator;
import org.python.pydev.parser.visitors.comparator.SimpleNodeComparator.LineColComparator;
Expand Down Expand Up @@ -105,6 +106,15 @@ public void testGenCythonFromCythonTests() throws Exception {

public void testGenCythonAstCases() throws Exception {
String[] cases = new String[] {
"def method(a, *, b):pass",
"@dec1\n@dec2\ndef method():pass",
"@dec\ndef method():pass",
"assert 1, 'ra'",
"def method(a:int): return 1",
"import a.b as c",
"import a\n"
+ "\n"
+ "import b\n",
"def method(*args, **kwargs):\n"
+ " return f(*args, **modify(kwargs))",
"{tuple(call(n) for n in (1, 2) if n == 2)}",
Expand All @@ -118,9 +128,6 @@ public void testGenCythonAstCases() throws Exception {
+ " yield from bar",
"a = lambda x:y",
"a = call(foo, foo=bar, **xx.yy)",
"import a\n"
+ "\n"
+ "import b\n",

"foo[:, b:c, d:e, f:g] = []",
"foo[:] = []",
Expand Down Expand Up @@ -183,14 +190,10 @@ public void testGenCythonAstCases() throws Exception {
"from a import b",
"from a.b import d as f",
"from a import b as c",
"import a.b as c",
"import a.b",
"import a",
"1 | 2 == 0",
"1 & 2 == 0",
"1 ^ 2 == 0",
"def method(a:int): return 1",
"assert 1, 'ra'",
"a = a + b",
"a = a - b",
"a = a * b",
Expand Down Expand Up @@ -220,15 +223,13 @@ public void testGenCythonAstCases() throws Exception {
"def method():\n a=10\n b=20",
"def method(a=None):pass",
"def method(a, *b, **c):pass",
"def method(a, *, b):pass",
"def method(a=1, *, b=2):pass",
"def method(a=1, *, b:int=2):pass",
"call()",
"call(1, 2, b(True, False))",
"call(u'nth')",
"call(b'nth')",
"call('nth')",
"@dec\ndef method():pass",
"@dec()\ndef method():pass",
"class A:pass",
"class A(set, object):pass",
Expand Down Expand Up @@ -261,8 +262,10 @@ public ParseOutput compareCase(String expected) throws DifferException, Exceptio
public ParseOutput compareCase(String expected, String cython) throws DifferException, Exception {
try {
return compareCase(expected, cython, false);
} catch (Exception e) {
throw new RuntimeException("Error with cython: " + cython, e);
} catch (Throwable e) {
final String msg = "Error with cython: " + cython;
System.err.println(msg);
throw new RuntimeException(msg, e);
}
}

Expand Down Expand Up @@ -302,7 +305,13 @@ public ParseOutput compareCase(String expected, String cython, LineColComparator
throw new RuntimeException("Error parsing: " + cython);
}

compareNodes(parseOutput.ast, cythonParseOutput.ast, lineColComparator);
try {
compareNodes(parseOutput.ast, cythonParseOutput.ast, lineColComparator);
} catch (Throwable e) {
System.err.println("Cython AST pretty-printed to: ");
System.err.println(NodeUtils.printAst(null, (SimpleNode) cythonParseOutput.ast));
throw e;
}
return cythonParseOutput;
}

Expand Down Expand Up @@ -334,13 +343,26 @@ public void compareWithAst(String code, String expectedAst) throws Misconfigurat
assertEquals(expectedAst, cythonParseOutput.ast.toString());
}

public void compareWithAst(String code, String[] expectedAstArray) throws MisconfigurationException {
ParserInfo parserInfo = new ParserInfo(new Document(code), grammarVersionProvider);
ParseOutput cythonParseOutput = new GenCythonAstImpl(parserInfo).genCythonAst();
String found = cythonParseOutput.ast.toString();
for (String s : expectedAstArray) {
if (s.equals(found)) {
return;
}
}
throw new AssertionError("Error: generated: " + found + "\n\nDoes not match any of the expected arrays.");
}

public void testGenCythonAstCornerCase1() throws Exception {
compareWithAst("(f'{a}{{}}nth')",
"Module[body=[Expr[value=Str[s=, type=SingleSingle, unicode=true, raw=false, binary=false, fstring=false, fstring_nodes=[Expr[value=Name[id=a, ctx=Load, reserved=false]], Expr[value=Str[s={}nth, type=SingleSingle, unicode=true, raw=false, binary=false, fstring=false, fstring_nodes=null]]]]]]]");
}

public void testGenCythonAstCornerCase2() throws Exception {
compareCase("a = u'>'", "a = c'>'");
compareCase("import a.b", "import a.b as a");

compareCase(
"\n"
Expand Down Expand Up @@ -407,7 +429,11 @@ public void testGenCythonAstCornerCase6() throws Exception {

public void testGenCythonAstCornerCase7() throws Exception {
compareWithAst("print(10)",
"Module[body=[Print[dest=null, values=[Num[n=10, type=Int, num=10]], nl=true]]]");
new String[] {

"Module[body=[Print[dest=null, values=[Num[n=10, type=Int, num=10]], nl=true]]]",
"Module[body=[Expr[value=Call[func=Name[id=print, ctx=Load, reserved=false], args=[Num[n=10, type=Int, num=10]], keywords=[], starargs=null, kwargs=null]]]]"
});

}

Expand Down Expand Up @@ -596,7 +622,33 @@ public void testGenCythonAstBasic() throws Exception {
JsonValue value = JsonValue.readFrom(output);

JsonValue body = value.asObject().get("ast").asObject().get("stats");
assertEquals(body, JsonValue.readFrom(
Object expect1 = JsonValue.readFrom(
"[ \n"
+ " { \n"
+ " \"__node__\": \"SingleAssignment\",\n"
+ " \"line\": 1, \n"
+ " \"col\": 4, \n"
+ " \"lhs\": { \n"
+ " \"__node__\": \"Name\", \n"
+ " \"line\": 1, \n"
+ " \"col\": 0, \n"
+ " \"name\": \"a\" \n"
+ " }, \n"
+ " \"rhs\": { \n"
+ " \"__node__\": \"Int\", \n"
+ " \"line\": 1, \n"
+ " \"col\": 4, \n"
+ " \"is_c_literal\": \"None\", \n"
+ " \"value\": \"10\", \n"
+ " \"unsigned\": \"\", \n"
+ " \"longness\": \"\", \n"
+ " \"constant_result\": \"10\", \n"
+ " \"type\": \"long\" \n"
+ " }, \n"
+ " \"first\": \"False\" \n" // This is new in Cython 3.0
+ " } \n"
+ "] ");
Object expect2 = JsonValue.readFrom(
"[\n" +
" {\n" +
" \"__node__\": \"SingleAssignment\",\n" +
Expand All @@ -621,7 +673,11 @@ public void testGenCythonAstBasic() throws Exception {
" }\n" +
" }\n" +
" ]\n" +
"\n"));
"\n");

if (!body.equals(expect1) && !body.equals(expect2)) {
throw new AssertionError("The body json doesn't match what we expect:\n" + body);
}

assertEquals(
"Module[body=[Assign[targets=[Name[id=a, ctx=Store, reserved=false]], value=Num[n=10, type=Int, num=10], type=null]]]",
Expand Down

0 comments on commit eb778b0

Please sign in to comment.