diff --git a/src/gen.jl b/src/gen.jl index c47f8ee..e0cc660 100644 --- a/src/gen.jl +++ b/src/gen.jl @@ -20,6 +20,37 @@ const _packages = Dict{String,String}() # maps protofile name to array of imported protofiles const protofile_imports = Dict() +# maps resolved names in the "protofile module name" -> "type name" -> ["nested type name",...] form +const name_maps = Dict() + +function add_name_map(name, fullname) + parts = rsplit(fullname, "."; limit=2, keepempty=false) + + if length(parts) == 1 + pkgname = "" + typname = parts[1] + else + pkgname = parts[1] + typname = parts[2] + end + + if name == typname + nested_level = "" + else + nested_level = typname[1:end-length(name)-1] + typname = name + end + + names_in_module = get!(name_maps, pkgname) do + Dict{String,Set{String}}() + end + nested_type_names = get!(names_in_module, typname) do + Set{String}() + end + push!(nested_type_names, nested_level) + nothing +end + # Treat Google Proto3 extensions specially as they are built into ProtoBuf.jl (for issue #77) const GOOGLE_PROTO3_EXTENSIONS = "google.protobuf" @@ -153,11 +184,23 @@ function field_type_name(full_type_name::String) if isempty(comps) type_name = full_type_name else - package_name = join(comps[1:(end - 1)], '.') - if package_name == GOOGLE_PROTO3_EXTENSIONS - type_name = "ProtoBuf.$full_type_name" - else - type_name = join(comps, '.') + type_name = join(comps, '.') + for level in (length(comps)-1):-1:1 + package_name = join(comps[1:level], '.') + if package_name == GOOGLE_PROTO3_EXTENSIONS + type_name = "ProtoBuf.$full_type_name" + break + elseif package_name in keys(name_maps) + type_maps = name_maps[package_name] + if last(comps) in keys(type_maps) + nested_level = join(comps[(level+1):(end-1)], '_') + if nested_level in type_maps[last(comps)] + type_name = isempty(nested_level) ? last(comps) : string(nested_level, "_", last(comps)) + isempty(package_name) || (type_name = string(package_name, ".", type_name)) + break + end + end + end end end @debug("usable type name for $full_type_name is $type_name") @@ -296,6 +339,7 @@ function generate_msgtype(outio::IO, errio::IO, dtype::DescriptorProto, scope::S modul,dtypename = splitmodule_chkkeyword(full_dtypename) full_dtypename = (modul=="") ? dtypename : "$(modul).$(dtypename)" @debug("begin type $(full_dtypename)") + add_name_map(dtype.name, full_dtypename) scope = Scope(dtype.name, scope) @@ -573,11 +617,7 @@ function generate_svc(io::IO, errio::IO, stype::ServiceDescriptorProto, scope::S println(io, "const _$(stype.name)_methods = MethodDescriptor[") for idx in 1:nmethods method = stype.method[idx] - in_typ_name = try - scoped_svc_type_name(scope, field_type_name(method.input_type)) - catch ex - throw(string(ex, " - ", fullname(scope))) - end + in_typ_name = scoped_svc_type_name(scope, field_type_name(method.input_type)) out_typ_name = scoped_svc_type_name(scope, field_type_name(method.output_type)) elem_sep = (idx < nmethods) ? "," : "" diff --git a/test/proto/t2.proto b/test/proto/t2.proto index b5e6cad..f48840c 100644 --- a/test/proto/t2.proto +++ b/test/proto/t2.proto @@ -11,3 +11,7 @@ message MA { message MC { required string c = 1; } + +service TestServicePkgP { + rpc Method1(MA.MB) returns(MC); +}