Skip to content

Commit

Permalink
Fix schema XSD sources export/download (issue #387)
Browse files Browse the repository at this point in the history
  - Modify dataclass XsdSource: now takes a path and an XMLResorce,
    other attributes are set in __init__;
  - Schema locations now are extracted from XML tree.
  • Loading branch information
brunato committed Apr 7, 2024
1 parent c1cc30c commit aefe4be
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions xmlschema/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from xml.etree import ElementTree

from .exceptions import XMLResourceError, XMLSchemaValueError
from .names import XSD_SCHEMA
from .names import XSD_SCHEMA, XSD_IMPORT, XSD_INCLUDE, XSD_REDEFINE, XSD_OVERRIDE
from .helpers import logged
from .locations import LocationPath, is_remote_url, normalize_url, match_location
from .translation import gettext as _
Expand All @@ -37,15 +37,27 @@
class XsdSource:
"""Class for keeping track of an XSD schema source."""
path: LocationPath
text: str
processed: bool = False
modified: bool = False
substitutions: Optional[List[Tuple[str, str]]] = None
resource: XMLResource

def __init__(self, path: LocationPath, resource: XMLResource) -> None:
self.path = path
self.resource = resource
self.text = resource.get_text()
self.processed = False
self.modified = False
self.substitutions: Optional[List[Tuple[str, str]]] = None

@property
def schema_locations(self) -> Set[str]:
"""Extract schema locations from XSD source."""
return set(x.strip() for x in re.findall(FIND_PATTERN, self.text))
"""Extract schema locations from XSD resource tree."""
locations = set()
for child in self.resource.root:
if child.tag in (XSD_IMPORT, XSD_INCLUDE, XSD_REDEFINE, XSD_OVERRIDE):
schema_location = child.get('schemaLocation', '').strip()
if schema_location:
locations.add(schema_location)

return locations

def replace_location(self, location: str, repl_location: str) -> None:
if location == repl_location:
Expand Down Expand Up @@ -214,7 +226,7 @@ def residuals_filter(x: str) -> bool:
logger.info("Export schema using loglevel %r", loglevel)

name = schema.name or 'schema.xsd'
exports = {schema: XsdSource(LocationPath(name), schema.get_text())}
exports = {schema: XsdSource(LocationPath(name), schema.source)}
path: Any

if exclude_locations is None:
Expand Down Expand Up @@ -257,7 +269,7 @@ def residuals_filter(x: str) -> bool:

path = schema_source.get_location_path(location, ref_schema)
if ref_schema not in exports:
exports[ref_schema] = XsdSource(path, ref_schema.get_text())
exports[ref_schema] = XsdSource(path, ref_schema.source)

if remove_residuals:
# Deactivate residual redundant imports from remote URLs
Expand Down Expand Up @@ -309,7 +321,7 @@ def download_schemas(url: str,

name = resource.name
downloads = {
resource: XsdSource(LocationPath(name), resource.get_text()) # type: ignore[arg-type]
resource: XsdSource(LocationPath(name), resource) # type: ignore[arg-type]
}
path: Any

Expand Down Expand Up @@ -355,7 +367,7 @@ def download_schemas(url: str,
continue

path = schema_source.get_location_path(location, ref_resource, modify)
downloads[ref_resource] = XsdSource(path, ref_resource.get_text())
downloads[ref_resource] = XsdSource(path, ref_resource)

if current_length == len(downloads):
break
Expand Down

0 comments on commit aefe4be

Please sign in to comment.