Source code for agentbx.schemas.generator

# src/agentbx/schemas/generator.py
"""
Schema generator for agentbx.

This module generates Pydantic schemas from YAML definitions.
"""

from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import yaml
from pydantic import BaseModel


class AssetDefinition(BaseModel):
    """Pydantic model for a single entry of a schema's ``asset_definitions``.

    ``description`` and ``data_type`` are mandatory.  Every other field is an
    optional annotation that defaults to ``None`` (``False`` for
    ``checksum_required``), so a YAML entry only needs to spell out the keys
    that apply to it.
    """

    # --- mandatory keys ---
    description: str
    data_type: str

    # --- general metadata ---
    shape: Optional[str] = None
    units: Optional[str] = None
    checksum_required: bool = False

    # --- structural requirements on the asset object ---
    required_attributes: Optional[List[str]] = None
    required_methods: Optional[List[str]] = None
    depends_on: Optional[List[str]] = None

    # --- value constraints ---
    must_be_complex: Optional[bool] = None
    must_be_real: Optional[bool] = None
    data_must_be_positive: Optional[bool] = None
    default_values: Optional[Dict[str, Any]] = None
    allowed_values: Optional[List[str]] = None
    valid_range: Optional[List[float]] = None

    # --- key constraints (for dict-like assets) ---
    required_keys: Optional[List[str]] = None
    optional_keys: Optional[List[str]] = None
    expected_keys: Optional[List[str]] = None
class ValidationRule(BaseModel):
    """Pydantic model for one named validation rule read from YAML.

    A rule is identified by ``rule_name``; ``parameters`` optionally carries
    a free-form mapping of settings for that rule.
    """

    rule_name: str
    parameters: Optional[Dict[str, Any]] = None
class WorkflowPattern(BaseModel):
    """Pydantic model for one entry of a schema's ``workflow_patterns``.

    Only ``pattern_name`` is mandatory.  The remaining keys are all optional
    and default to ``None``; a YAML pattern lists whichever of them it uses.
    """

    pattern_name: str

    # List-valued keys.
    requires: Optional[List[str]] = None
    produces: Optional[List[str]] = None
    modifies: Optional[List[str]] = None
    preserves: Optional[List[str]] = None

    # ``method`` is a plain string rather than a list.
    method: Optional[str] = None

    enables: Optional[List[str]] = None
    input: Optional[List[str]] = None
    output: Optional[List[str]] = None

    # ``process`` is a plain string rather than a list.
    process: Optional[str] = None

    checks: Optional[List[str]] = None
    outputs: Optional[List[str]] = None
    validates: Optional[List[str]] = None
class SchemaDefinition(BaseModel):
    """Complete schema definition parsed from one YAML file.

    ``task_type``, ``description``, ``required_assets`` and
    ``asset_definitions`` must be supplied; the remaining collections default
    to empty containers.
    """

    # --- mandatory header ---
    task_type: str
    description: str

    # --- asset names ---
    required_assets: List[str]
    optional_assets: Optional[List[str]] = []

    # --- per-asset details, keyed by asset name ---
    asset_definitions: Dict[str, AssetDefinition]
    # Changed from List to Dict
    validation_rules: Optional[Dict[str, Dict[str, Any]]] = {}
    workflow_patterns: Optional[Dict[str, WorkflowPattern]] = {}

    # --- relations to other schemas ---
    dependencies: Optional[List[str]] = []
    produces_for: Optional[List[str]] = []
class SchemaGenerator:
    """Generate Pydantic bundle models from YAML schema definitions.

    The generator reads every ``*.yaml`` file in ``schema_dir``, parses each
    into a :class:`SchemaDefinition`, and renders one ``<TaskType>Bundle``
    Pydantic class per schema as Python source text.
    """

    def __init__(self, schema_dir: Path):
        """Remember the schema directory and start with no schemas loaded.

        Args:
            schema_dir: Directory that contains the YAML schema files.
        """
        self.schema_dir = Path(schema_dir)
        self.schemas: Dict[str, SchemaDefinition] = {}
        self.generated_models: Dict[str, type] = {}

    def load_schema(self, schema_file: Path) -> SchemaDefinition:
        """Load and parse a single YAML schema file.

        Asset definitions and workflow patterns that fail to parse are
        reported and skipped; the rest of the schema is still returned.

        Args:
            schema_file: Path of the YAML file to read.

        Returns:
            The parsed schema.

        Raises:
            ValueError: If the file is empty (YAML content is ``None``).
            Exception: Any other error is printed with a traceback and
                re-raised unchanged.
        """
        try:
            print(f"🔍 Loading {schema_file.name}...")
            # Read as UTF-8 explicitly so parsing does not depend on the
            # platform's locale encoding.
            with open(schema_file, encoding="utf-8") as f:
                raw_schema = yaml.safe_load(f)

            if raw_schema is None:
                raise ValueError(f"Empty or invalid YAML file: {schema_file}")

            print(f" ✅ YAML loaded, keys: {list(raw_schema.keys())}")

            # Convert asset_definitions to AssetDefinition objects
            print(" 🔧 Processing asset_definitions...")
            asset_defs = self._parse_asset_definitions(
                raw_schema.get("asset_definitions", {})
            )

            # Convert workflow_patterns to WorkflowPattern objects
            print(" 🔧 Processing workflow_patterns...")
            workflow_patterns = self._parse_workflow_patterns(
                raw_schema.get("workflow_patterns", {})
            )

            # Handle validation_rules safely
            print(" 🔧 Processing validation_rules...")
            validation_rules = self._normalize_validation_rules(
                raw_schema.get("validation_rules", {})
            )
            print(
                " ✅ Validation rules processed: {}".format(
                    list(validation_rules.keys())
                )
            )

            # Create the full schema definition
            print(" 🔧 Creating SchemaDefinition...")
            schema_data = {
                "task_type": raw_schema.get("task_type", ""),
                "description": raw_schema.get("description", ""),
                "required_assets": raw_schema.get("required_assets", []),
                "optional_assets": raw_schema.get("optional_assets", []),
                "asset_definitions": asset_defs,
                "validation_rules": validation_rules,
                "workflow_patterns": workflow_patterns,
                "dependencies": raw_schema.get("dependencies", []),
                "produces_for": raw_schema.get("produces_for", []),
            }

            schema = SchemaDefinition(**schema_data)
            print(" ✅ SchemaDefinition created successfully")
            return schema

        except Exception as e:
            print(f"❌ Error loading {schema_file.name}: {e}")
            print(f" Error type: {type(e).__name__}")
            import traceback

            traceback.print_exc()  # pragma: no cover
            raise  # pragma: no cover

    @staticmethod
    def _parse_asset_definitions(raw_assets: Any) -> Dict[str, AssetDefinition]:
        """Build AssetDefinition objects from the raw ``asset_definitions`` mapping."""
        parsed: Dict[str, AssetDefinition] = {}
        if not raw_assets:
            return parsed

        for asset_name, asset_data in raw_assets.items():
            if asset_data is None:
                print(f"Warning: asset_data is None for {asset_name}")
                continue
            try:
                parsed[asset_name] = AssetDefinition(**asset_data)
                print(f" ✅ {asset_name}")
            except Exception as e:
                # Deliberately best-effort: one bad asset is reported and
                # skipped so the remaining assets still load.
                print(
                    " ❌ Error creating AssetDefinition for {}: {}".format(
                        asset_name, e
                    )
                )
                continue
        return parsed

    @staticmethod
    def _parse_workflow_patterns(raw_patterns: Any) -> Dict[str, WorkflowPattern]:
        """Build WorkflowPattern objects from the raw ``workflow_patterns`` mapping."""
        parsed: Dict[str, WorkflowPattern] = {}
        if not raw_patterns:
            return parsed

        if not isinstance(raw_patterns, dict):
            print(
                "⚠️ Unexpected workflow_patterns type: {} (should be dict)".format(
                    type(raw_patterns)
                )
            )
            return parsed

        for pattern_name, pattern_data in raw_patterns.items():
            print(
                " 🔍 Processing pattern '{}': {}".format(
                    pattern_name, type(pattern_data)
                )
            )

            if pattern_data is None:
                pattern_data = {}

            # Handle different workflow pattern structures
            if isinstance(pattern_data, list):
                print(" 🔄 Converting list to dict for {}".format(pattern_name))
                # Convert list format to dict format
                combined_data = {"pattern_name": pattern_name}
                for item in pattern_data:
                    if isinstance(item, dict):
                        combined_data.update(item)
                        print(f" 📝 Added: {list(item.keys())}")
                    elif isinstance(item, str):
                        print(f" 📝 String item: {item}")
                        # Handle simple "key: value" string items; strings
                        # without a colon carry no key and are ignored.
                        if ":" in item:
                            key, val = item.split(":", 1)
                            combined_data[key.strip()] = val.strip().strip('"')
                pattern_data = combined_data
                print(" ✅ Converted to: {}".format(list(pattern_data.keys())))
            elif isinstance(pattern_data, dict):
                pattern_data["pattern_name"] = pattern_name
                print(" ✅ Dict format: {}".format(list(pattern_data.keys())))
            else:
                print(
                    " ⚠️ Unexpected pattern_data type: {}".format(type(pattern_data))
                )
                continue

            try:
                parsed[pattern_name] = WorkflowPattern(**pattern_data)
                print(f" ✅ {pattern_name}")
            except Exception as e:
                # Deliberately best-effort: report and skip this pattern.
                print(
                    " ❌ Error creating WorkflowPattern for {}: {}".format(
                        pattern_name, e
                    )
                )
                print(f" Pattern data: {pattern_data}")
                continue
        return parsed

    @staticmethod
    def _normalize_validation_rules(raw_rules: Any) -> Dict[str, Dict[str, Any]]:
        """Return ``validation_rules`` as a dict of dicts, flattening list-style entries."""
        normalized: Dict[str, Dict[str, Any]] = {}
        # A missing/None section and any non-dict section both yield no rules.
        if not isinstance(raw_rules, dict):
            return normalized

        for asset_name, rules in raw_rules.items():
            print(
                " 🔍 Processing validation for '{}': {}".format(
                    asset_name, type(rules)
                )
            )

            if isinstance(rules, list):
                print(f" 🔄 Converting list to dict for {asset_name}")
                # Flatten list of dicts into single dict
                combined_rules: Dict[str, Any] = {}
                for rule_item in rules:
                    if isinstance(rule_item, dict):
                        combined_rules.update(rule_item)
                        print(" 📝 Added rules: {}".format(list(rule_item.keys())))
                    elif isinstance(rule_item, str):
                        # A bare string is treated as a flag that is switched on.
                        combined_rules[rule_item] = True
                        print(f" 📝 Added string rule: {rule_item}")
                normalized[asset_name] = combined_rules
                print(" ✅ Final rules: {}".format(list(combined_rules.keys())))
            elif isinstance(rules, dict):
                normalized[asset_name] = rules
                print(f" ✅ Dict format: {list(rules.keys())}")
            else:
                print(
                    "Warning: Expected dict for validation rules, got {} for {}".format(
                        type(rules), asset_name
                    )
                )
        return normalized

    def load_all_schemas(self) -> None:
        """Load all YAML schema files from the schema directory.

        Successfully parsed schemas are stored in ``self.schemas`` keyed by
        their ``task_type``.  A file that fails to load is reported and
        skipped so the remaining files are still processed.
        """
        # Sort so the load order -- and therefore the order of the generated
        # classes -- does not depend on filesystem enumeration order.
        schema_files = sorted(self.schema_dir.glob("*.yaml"))

        if not schema_files:
            print(f"No .yaml files found in {self.schema_dir}")
            return

        print(f"Found {len(schema_files)} YAML files:")
        for schema_file in schema_files:
            print(f" - {schema_file.name}")

        for schema_file in schema_files:
            try:
                schema = self.load_schema(schema_file)
                self.schemas[schema.task_type] = schema
                print(f"✅ Loaded schema: {schema.task_type}")
            except Exception as e:
                # Continue with other files instead of stopping
                print(f"❌ Error loading {schema_file.name}: {e}")

    def generate_asset_model(
        self, asset_name: str, asset_def: AssetDefinition, optional: bool = False
    ) -> str:
        """Generate the Pydantic field definition line for one asset.

        Args:
            asset_name: Name of the field in the generated class.
            asset_def: Parsed definition of the asset.
            optional: When true, the type is wrapped in ``Optional[...]`` and
                the field defaults to ``None`` unless the asset definition
                supplies its own default.

        Returns:
            One line of source, indented by four spaces, of the form
            ``name: Type = Field(...)`` (or ``name: Type = ...`` when a
            required field has no ``Field`` arguments).
        """
        import json

        # Map CCTBX types to Python types for Pydantic
        type_mapping = {
            "cctbx.xray.structure": "Any",  # Will need custom validation
            "cctbx.miller.array": "Any",
            "cctbx.miller.set": "Any",
            "cctbx.array_family.flex.vec3_double": "Any",
            "cctbx.array_family.flex.double": "Any",
            "dict": "Dict[str, Any]",
            "str": "str",
            "float": "float",
            "int": "int",
            "bool": "bool",
            "bytes": "bytes",
        }
        python_type = type_mapping.get(asset_def.data_type, "Any")

        # Build field definition
        field_kwargs: List[str] = []
        if asset_def.description:
            # json.dumps yields a double-quoted literal with quotes,
            # backslashes and newlines escaped, so arbitrary description text
            # cannot break the generated source.  Plain text is unchanged.
            field_kwargs.append(
                "description="
                + json.dumps(asset_def.description, ensure_ascii=False)
            )

        # Add constraints based on asset definition
        if asset_def.valid_range and python_type == "float":
            field_kwargs.append(
                "ge={}, le={}".format(
                    asset_def.valid_range[0], asset_def.valid_range[1]
                )
            )

        # A default is only emitted when exactly one default value is given.
        has_default = bool(asset_def.default_values) and (
            len(asset_def.default_values) == 1
        )
        if has_default:
            default_val = next(iter(asset_def.default_values.values()))
            field_kwargs.append(f"default={default_val!r}")

        if optional:
            if not python_type.startswith("Optional["):
                python_type = f"Optional[{python_type}]"
            # Never emit ``default=`` twice: a repeated keyword argument is a
            # SyntaxError in the generated module.
            if not has_default:
                field_kwargs.insert(0, "default=None")

        field_def = (
            "Field({})".format(", ".join(field_kwargs)) if field_kwargs else "..."
        )

        return f"    {asset_name}: {python_type} = {field_def}"

    def generate_validators(self, schema: SchemaDefinition) -> List[str]:
        """Generate custom validators for CCTBX-specific validation.

        A validator is produced for every asset that has an entry in
        ``schema.validation_rules`` and is also present in
        ``schema.asset_definitions``.

        Args:
            schema: Parsed schema to generate validators for.

        Returns:
            Source lines (already indented for use inside the class body).
        """
        validators: List[str] = []

        # Generate validators for each asset with validation rules
        if not schema.validation_rules:
            return validators

        for asset_name, rules in schema.validation_rules.items():
            if asset_name not in schema.asset_definitions:
                continue
            asset_def = schema.asset_definitions[asset_name]

            # NOTE(review): the description is pasted into a triple-quoted
            # docstring unescaped; text containing '"""' would break it.
            validator_lines = [
                f"    @field_validator('{asset_name}')",
                "    @classmethod",
                f"    def validate_{asset_name}(cls, v):",
                f'        """Validate {asset_def.description}"""',
            ]

            # Add CCTBX-specific validation
            if asset_def.data_type == "cctbx.xray.structure":
                validator_lines.extend(
                    [
                        "        if not hasattr(v, 'scatterers'):",
                        "            raise ValueError('xray_structure must have scatterers')",
                        "        if not hasattr(v, 'unit_cell'):",
                        "            raise ValueError('xray_structure must have unit_cell')",
                    ]
                )
            elif asset_def.data_type == "cctbx.miller.array":
                validator_lines.extend(
                    [
                        "        if not hasattr(v, 'indices'):",
                        "            raise ValueError('miller_array must have indices')",
                        "        if not hasattr(v, 'data'):",
                        "            raise ValueError('miller_array must have data')",
                    ]
                )

                if asset_def.must_be_complex:
                    validator_lines.append("        if not v.is_complex_array():")
                    validator_lines.append(
                        "            raise ValueError('miller_array must be complex')"
                    )

                if asset_def.data_must_be_positive:
                    validator_lines.append(
                        "        if (v.data() < 0).count(True) > 0:"
                    )
                    validator_lines.append(
                        "            raise ValueError('miller_array data must be positive')"
                    )

            # Handle validation rules - rules is now a dict, not a list
            if isinstance(rules, dict):
                for rule_name, rule_value in rules.items():
                    if rule_name == "finite_values_only" and rule_value:
                        validator_lines.extend(
                            [
                                "        import numpy as np",
                                "        if hasattr(v, 'data') and not np.all(np.isfinite(v.data())):",
                                "            raise ValueError('All values must be finite')",
                            ]
                        )
            else:
                print(
                    "Warning: Expected dict for validation rules, got {} for {}".format(
                        type(rules), asset_name
                    )
                )

            validator_lines.append("        return v")
            validator_lines.append("")

            validators.extend(validator_lines)

        return validators

    def generate_bundle_model(self, schema: SchemaDefinition) -> str:
        """Generate a complete Pydantic model for a bundle type.

        The class is named after ``schema.task_type`` (title-cased, with
        underscores removed, plus the ``Bundle`` suffix).  Asset names listed
        in the schema but missing from ``asset_definitions`` are skipped.

        Args:
            schema: Parsed schema to render.

        Returns:
            Source text of the class, lines joined with ``"\\n"``.
        """
        class_name = "{}Bundle".format(schema.task_type.title().replace("_", ""))

        lines = [
            f"class {class_name}(BaseModel):",
            '    """',
            f"    {schema.description}",
            "",
            f"    Generated from {schema.task_type}.yaml",
            '    """',
            "",
            "    # Bundle metadata",
            '    bundle_type: Literal["{}"] = "{}"'.format(
                schema.task_type, schema.task_type
            ),
            "    created_at: datetime = Field(default_factory=datetime.now)",
            "    bundle_id: Optional[str] = None",
            "    checksum: Optional[str] = None",
            "",
        ]

        # Add required assets
        lines.append("    # Required assets")
        for asset_name in schema.required_assets:
            if asset_name in schema.asset_definitions:
                asset_def = schema.asset_definitions[asset_name]
                lines.append(self.generate_asset_model(asset_name, asset_def))

        lines.append("")

        # Add optional assets
        if schema.optional_assets:
            lines.append("    # Optional assets")
            for asset_name in schema.optional_assets:
                if asset_name in schema.asset_definitions:
                    asset_def = schema.asset_definitions[asset_name]
                    # Let the field generator build the Optional[...] form
                    # directly instead of re-parsing its output as text.
                    lines.append(
                        self.generate_asset_model(
                            asset_name, asset_def, optional=True
                        )
                    )

            lines.append("")

        # Add custom validators
        lines.extend(self.generate_validators(schema))

        # Add utility methods
        # NOTE(review): the generated code calls ``self.dict(...)`` while the
        # generated imports use ``field_validator``; ``model_dump`` may be the
        # preferred spelling for that pydantic version -- confirm.
        lines.extend(
            [
                "    def calculate_checksum(self) -> str:",
                '        """Calculate checksum of bundle contents."""',
                "        # Implementation would hash all asset contents",
                "        import json",
                "        content = self.dict(exclude={'checksum', 'created_at', 'bundle_id'})",
                "        content_str = json.dumps(content, sort_keys=True, default=str)",
                "        return hashlib.sha256(content_str.encode()).hexdigest()[:16]",
                "",
                "    def validate_dependencies(self, available_bundles: Dict[str, 'BaseModel']) -> bool:",
                '        """Validate that all dependencies are satisfied."""',
                "        required_deps: List[str] = {} or []".format(
                    schema.dependencies
                ),
                "        for dep in required_deps:",
                "            if dep not in available_bundles:",
                "                raise ValueError(f'Missing dependency: {dep}')",
                "        return True",
                "",
            ]
        )

        return "\n".join(lines)

    def generate_all_models(self) -> str:
        """Generate Pydantic models for all loaded schemas.

        Loads the schemas from ``self.schema_dir`` first if none have been
        loaded yet.

        Returns:
            Source text of a module: an import header followed by one bundle
            class per loaded schema.
        """
        if not self.schemas:
            self.load_all_schemas()

        # Generate imports
        imports = [
            "# Auto-generated Pydantic models from YAML schemas",
            "# DO NOT EDIT - regenerate using SchemaGenerator",
            "",
            "from typing import Dict, List, Optional, Any, Literal",
            "from pydantic import BaseModel, Field, field_validator",
            "from datetime import datetime",
            "import hashlib",
            "",
        ]

        # Generate each bundle model
        models = [
            self.generate_bundle_model(schema) for schema in self.schemas.values()
        ]

        return "\n".join(imports + models)

    def write_generated_models(self, output_file: Path) -> None:
        """Write generated models to a Python file.

        Args:
            output_file: Destination path; an existing file is overwritten.
        """
        generated_code = self.generate_all_models()

        # Descriptions copied from YAML may contain non-ASCII text, so write
        # UTF-8 regardless of the platform's locale encoding.
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(generated_code)

        print(f"Generated Pydantic models written to: {output_file}")
def main() -> int:
    """Command-line entry point that generates the models file.

    Parses ``--schemas-dir``, ``--output``, ``--watch`` and ``--verbose``,
    loads every schema from the schemas directory, writes the generated
    models to the output file and, with ``--watch``, keeps regenerating on
    changes.

    Returns:
        ``0`` on success; ``1`` when the schemas directory is missing, no
        schema could be loaded, or generation raised an exception.
    """
    import argparse

    # Set up argument parser
    cli = argparse.ArgumentParser(
        description="Generate Pydantic models from YAML schema definitions"
    )
    cli.add_argument(
        "--schemas-dir",
        type=Path,
        default=Path("src/agentbx/schemas/definitions"),
        help="Directory containing YAML schema files (default: src/agentbx/schemas/definitions)",
    )
    cli.add_argument(
        "--output",
        type=Path,
        default=Path("src/agentbx/schemas/generated.py"),
        help="Output file for generated Pydantic models (default: src/agentbx/schemas/generated.py)",
    )
    cli.add_argument(
        "--watch", action="store_true", help="Watch for changes and auto-regenerate"
    )
    cli.add_argument("--verbose", "-v", action="store_true", help="Verbose output")

    args = cli.parse_args()  # pragma: no cover

    schemas_dir = args.schemas_dir
    output_path = args.output
    verbose = args.verbose

    # Check if directories exist
    if not schemas_dir.exists():
        print(f"❌ Schema directory not found: {schemas_dir}")
        print(f"💡 Create it with: mkdir -p {schemas_dir}")
        return 1  # pragma: no cover

    # Ensure output directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print("🏭 AgentBx Schema Generator")
    print("=" * 40)
    print(f"📂 Schemas directory: {schemas_dir}")
    print(f"📄 Output file: {output_path}")

    # Generate schemas
    generator = SchemaGenerator(schemas_dir)

    try:
        # Load all schemas
        if verbose:
            print("\n📖 Loading YAML schemas...")
        generator.load_all_schemas()

        if not generator.schemas:
            print("⚠️ No YAML schema files found!")
            print(f" Add .yaml files to: {schemas_dir}")
            return 1  # pragma: no cover

        # Show loaded schemas
        print(f"\n✅ Loaded {len(generator.schemas)} schemas:")
        for schema_name, schema in generator.schemas.items():
            n_required = len(schema.required_assets)
            n_optional = len(schema.optional_assets or [])
            print(f" - {schema_name} ({n_required}, {n_optional} optional assets)")

        # Generate and write models
        if verbose:
            print("\n🔧 Generating Pydantic models...")
        generator.write_generated_models(output_path)

        print(f"\n🎉 Successfully generated: {output_path}")
        print(f"📊 Generated {len(generator.schemas)} bundle classes")

        # Show usage hint
        print("\n💡 Usage:")
        print(f" from {output_path.stem} import XrayAtomicModelDataBundle")

        if args.watch:
            watch_for_changes(generator, schemas_dir, output_path, verbose)

        return 0

    except Exception as e:
        print(f"❌ Error generating schemas: {e}")
        if verbose:
            import traceback

            traceback.print_exc()  # pragma: no cover
        return 1  # pragma: no cover
def watch_for_changes(
    generator: Any, schemas_dir: Path, output_file: Path, verbose: bool = False
) -> None:
    """Watch for changes in schema files and auto-regenerate.

    Blocks until interrupted with Ctrl+C.  If the optional ``watchdog``
    package is not installed, a hint is printed and the function returns
    immediately.

    Args:
        generator: Object providing ``schemas``, ``load_all_schemas()`` and
            ``write_generated_models()``.
        schemas_dir: Directory to watch (recursively) for ``.yaml`` changes.
        output_file: File the regenerated models are written to.
        verbose: Stored on the event handler; not otherwise used here.
    """
    import time

    # Guard only the optional imports, so an ImportError raised anywhere else
    # is not misreported as "watchdog is missing".
    try:
        from watchdog.events import FileSystemEventHandler
        from watchdog.observers import Observer
    except ImportError:
        print(
            "⚠️ Install watchdog for file watching: pip install watchdog"
        )  # pragma: no cover
        return

    print(f"\n👀 Watching {schemas_dir} for changes... (Ctrl+C to stop)")

    class SchemaChangeHandler(FileSystemEventHandler):
        """Regenerate the models file whenever a ``.yaml`` file is modified."""

        def __init__(self, generator, output_file, verbose):
            self.generator = generator
            self.output_file = output_file
            self.verbose = verbose
            self.last_regeneration = 0

        def on_modified(self, event):
            if event.is_directory or not event.src_path.endswith(".yaml"):
                return

            # Debounce rapid changes
            now = time.time()
            if now - self.last_regeneration < 1.0:
                return
            self.last_regeneration = now

            print(
                "\n🔄 {} changed, regenerating...".format(Path(event.src_path).name)
            )

            try:
                # Drop cached schemas so removed or renamed task types do
                # not linger in the regenerated output.
                self.generator.schemas.clear()
                self.generator.load_all_schemas()
                self.generator.write_generated_models(self.output_file)
                print("✅ Regenerated successfully!")
            except Exception as e:
                # Keep watching even if one regeneration fails.
                print(f"❌ Regeneration failed: {e}")

    event_handler = SchemaChangeHandler(generator, output_file, verbose)
    observer = Observer()
    observer.schedule(event_handler, str(schemas_dir), recursive=True)
    observer.start()

    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("\n🛑 Stopping file watcher...")
    finally:
        # Always shut the observer thread down, whichever way the loop ends.
        observer.stop()
        observer.join()
def quick_generate() -> None:
    """Generate the models file using the project's default paths.

    Reads schemas from ``src/agentbx/schemas/definitions`` and writes the
    result to ``src/agentbx/schemas/generated.py`` (both relative to the
    current working directory).

    Raises:
        FileNotFoundError: If the default schema directory does not exist.
    """
    definitions_dir = Path("src/agentbx/schemas/definitions")
    target = Path("src/agentbx/schemas/generated.py")

    # Fail early with a clear message rather than generating an empty file.
    if not definitions_dir.exists():
        raise FileNotFoundError(f"Schema directory not found: {definitions_dir}")

    gen = SchemaGenerator(definitions_dir)
    gen.write_generated_models(target)
    print(f"✅ Generated {len(gen.schemas)} schemas → {target}")
if __name__ == "__main__":
    # The bare ``exit`` helper is installed by the ``site`` module for
    # interactive use and may be absent (e.g. ``python -S``); raising
    # SystemExit directly always works and needs no import.
    raise SystemExit(main())  # pragma: no cover