Skip to content

Commit

Permalink
Merge pull request #175 from JuliaIO/tan/gen
Browse files Browse the repository at this point in the history
correct codegen for nested protobuf structs
  • Loading branch information
tanmaykm authored Apr 7, 2021
2 parents 71aecad + 5ffb048 commit b23a146
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
60 changes: 50 additions & 10 deletions src/gen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,37 @@ const _packages = Dict{String,String}()
# maps protofile name to array of imported protofiles
const protofile_imports = Dict()

# maps resolved names in the "protofile module name" -> "type name" -> ["nested type name",...] form
const name_maps = Dict()

function add_name_map(name, fullname)
parts = rsplit(fullname, "."; limit=2, keepempty=false)

if length(parts) == 1
pkgname = ""
typname = parts[1]
else
pkgname = parts[1]
typname = parts[2]
end

if name == typname
nested_level = ""
else
nested_level = typname[1:end-length(name)-1]
typname = name
end

names_in_module = get!(name_maps, pkgname) do
Dict{String,Set{String}}()
end
nested_type_names = get!(names_in_module, typname) do
Set{String}()
end
push!(nested_type_names, nested_level)
nothing
end

# Treat Google Proto3 extensions specially as they are built into ProtoBuf.jl (for issue #77)
const GOOGLE_PROTO3_EXTENSIONS = "google.protobuf"

Expand Down Expand Up @@ -153,11 +184,23 @@ function field_type_name(full_type_name::String)
if isempty(comps)
type_name = full_type_name
else
package_name = join(comps[1:(end - 1)], '.')
if package_name == GOOGLE_PROTO3_EXTENSIONS
type_name = "ProtoBuf.$full_type_name"
else
type_name = join(comps, '.')
type_name = join(comps, '.')
for level in (length(comps)-1):-1:1
package_name = join(comps[1:level], '.')
if package_name == GOOGLE_PROTO3_EXTENSIONS
type_name = "ProtoBuf.$full_type_name"
break
elseif package_name in keys(name_maps)
type_maps = name_maps[package_name]
if last(comps) in keys(type_maps)
nested_level = join(comps[(level+1):(end-1)], '_')
if nested_level in type_maps[last(comps)]
type_name = isempty(nested_level) ? last(comps) : string(nested_level, "_", last(comps))
isempty(package_name) || (type_name = string(package_name, ".", type_name))
break
end
end
end
end
end
@debug("usable type name for $full_type_name is $type_name")
Expand Down Expand Up @@ -296,6 +339,7 @@ function generate_msgtype(outio::IO, errio::IO, dtype::DescriptorProto, scope::S
modul,dtypename = splitmodule_chkkeyword(full_dtypename)
full_dtypename = (modul=="") ? dtypename : "$(modul).$(dtypename)"
@debug("begin type $(full_dtypename)")
add_name_map(dtype.name, full_dtypename)

scope = Scope(dtype.name, scope)

Expand Down Expand Up @@ -573,11 +617,7 @@ function generate_svc(io::IO, errio::IO, stype::ServiceDescriptorProto, scope::S
println(io, "const _$(stype.name)_methods = MethodDescriptor[")
for idx in 1:nmethods
method = stype.method[idx]
in_typ_name = try
scoped_svc_type_name(scope, field_type_name(method.input_type))
catch ex
throw(string(ex, " - ", fullname(scope)))
end
in_typ_name = scoped_svc_type_name(scope, field_type_name(method.input_type))
out_typ_name = scoped_svc_type_name(scope, field_type_name(method.output_type))
elem_sep = (idx < nmethods) ? "," : ""

Expand Down
4 changes: 4 additions & 0 deletions test/proto/t2.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ message MA {
message MC {
required string c = 1;
}

service TestServicePkgP {
rpc Method1(MA.MB) returns(MC);
}

0 comments on commit b23a146

Please sign in to comment.