Skip to content

Model Structure (Advanced)

For Advanced Users

This section documents the internal structure of the Model class and its components.

**Most users don't need this** — use `ModelBuilder` to create models instead.

This is useful for:
- Understanding the internal model representation
- Working with `ModelLoader.from_json()`
- Contributing to the library
- Debugging complex models

Model

Model

Bases: BaseModel

Root class of an epidemiological model.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `name` | `str` | A unique name that identifies the model. |
| `description` | `str \| None` | A human-readable description of the model's purpose and function. |
| `version` | `str \| None` | The version number of the model. |
| `population` | `Population` | Population details, subpopulations, stratifications and initial conditions. |
| `parameters` | `list[Parameter]` | A list of global model parameters. |
| `dynamics` | `Dynamics` | The rules that govern system evolution. |

Source code in epimodel/context/model.py (lines 13–573)
class Model(BaseModel):
    """
    Root class of an epidemiological model.

    Attributes
    ----------
    name : str
        A unique name that identifies the model.
    description : str | None
        A human-readable description of the model's purpose and function.
    version : str | None
        The version number of the model.
    population : Population
        Population details, subpopulations, stratifications and initial conditions.
    parameters : list[Parameter]
        A list of global model parameters.
    dynamics : Dynamics
        The rules that govern system evolution.
    """

    name: str = Field(..., description="Name which identifies the model.")
    description: str | None = Field(
        None,
        description="Human-readable description of the model's purpose and function.",
    )
    version: str | None = Field(None, description="Version number of the model.")

    population: Population
    parameters: list[Parameter]
    dynamics: Dynamics

    @model_validator(mode="after")
    def validate_unique_parameter_ids(self) -> Self:
        """
        Validates that parameter IDs are unique.

        Raises
        ------
        ValueError
            If two or more parameters share an ID. Duplicates are listed in
            order of first appearance so the message is deterministic.
        """
        from collections import Counter

        id_counts = Counter(p.id for p in self.parameters)
        duplicates = [pid for pid, count in id_counts.items() if count > 1]
        if duplicates:
            raise ValueError(f"Duplicate parameter IDs found: {duplicates}")
        return self

    @model_validator(mode="after")
    def validate_formula_variables(self) -> Self:
        """
        Validate that all variables in rate expressions are defined.

        This is done by gathering all valid identifiers once and checking each
        transition's rate expressions against them.

        Raises
        ------
        ValueError
            If any rate expression references an undefined identifier.
        """
        valid_identifiers = self._get_valid_identifiers()

        for transition in self.dynamics.transitions:
            self._validate_transition_rates(transition, valid_identifiers)
        return self

    def _get_valid_identifiers(self) -> set[str]:
        """
        Gathers all valid identifiers for use in rate expressions.

        The result is the union of parameter IDs, disease state IDs,
        stratification category IDs, the fixed special names
        ``N``, ``step``, ``pi``, ``e``, ``t`` and every ``N_<categories>``
        subpopulation name.
        """
        special_vars = {"N", "step", "pi", "e", "t"}
        param_ids = {param.id for param in self.parameters}
        state_ids = {state.id for state in self.population.disease_states}

        strat_category_ids: set[str] = {
            cat for strat in self.population.stratifications for cat in strat.categories
        }

        subpopulation_n_vars = self._get_subpopulation_n_vars()

        return (
            param_ids
            | state_ids
            | strat_category_ids
            | special_vars
            | subpopulation_n_vars
        )

    def _get_subpopulation_n_vars(self) -> set[str]:
        """
        Generates all possible ``N_{category...}`` variable names.

        For every non-empty subset of stratifications (kept in declaration
        order) and every choice of one category from each stratification in
        that subset, the name ``N_<cat1>_<cat2>...`` is produced.

        Returns an empty set when there are no stratifications, or when any
        stratification has no categories (then no full category combination
        exists at all).
        """
        category_groups = [s.categories for s in self.population.stratifications]

        # An empty category list means the full cross product is empty, which
        # previously yielded no names at all; keep that behavior.
        if not category_groups or not all(category_groups):
            return set()

        subpopulation_n_vars: set[str] = set()
        # Enumerate subsets of stratifications directly instead of taking every
        # subset of every full combination, which regenerated the same names
        # many times over.
        for size in range(1, len(category_groups) + 1):
            for group_subset in combinations(category_groups, size):
                for categories in product(*group_subset):
                    subpopulation_n_vars.add(f"N_{'_'.join(categories)}")

        return subpopulation_n_vars

    def _validate_transition_rates(
        self, transition: Transition, valid_identifiers: set[str]
    ) -> None:
        """
        Validates the rate expressions for a single transition.

        Checks ``transition.rate`` (when set) and the ``rate`` of every entry
        in ``transition.stratified_rates`` (when set).
        """
        if transition.rate:
            self._validate_rate_expression(
                transition.rate, transition.id, "rate", valid_identifiers
            )

        if transition.stratified_rates:
            for sr in transition.stratified_rates:
                self._validate_rate_expression(
                    sr.rate, transition.id, "stratified_rate", valid_identifiers
                )

    def _validate_rate_expression(
        self, rate: str, transition_id: str, context: str, valid_identifiers: set[str]
    ) -> None:
        """
        Validates variables in a single rate expression.

        Raises
        ------
        ValueError
            If the expression uses identifiers not in ``valid_identifiers``.
            The undefined names are reported once each, in sorted order, along
            with the available parameters and disease states.
        """
        variables = get_expression_variables(rate)
        # Sorted and de-duplicated so the message is stable and lists each
        # undefined name only once.
        undefined_vars = sorted(
            {var for var in variables if var not in valid_identifiers}
        )
        if undefined_vars:
            param_ids = {param.id for param in self.parameters}
            state_ids = {state.id for state in self.population.disease_states}
            raise ValueError(
                (
                    f"Undefined variables in transition '{transition_id}' "
                    f"{context} '{rate}': {', '.join(undefined_vars)}. "
                    f"Available parameters: "
                    f"{', '.join(sorted(param_ids)) if param_ids else 'none'}. "
                    f"Available disease states: "
                    f"{', '.join(sorted(state_ids)) if state_ids else 'none'}."
                )
            )

    @model_validator(mode="after")
    def validate_transition_ids(self) -> Self:
        """
        Validates that transition ids (source/target) are consistent in type
        and match the defined DiseaseState IDs or Stratification Categories
        in the Population instance.

        Raises
        ------
        ValueError
            If a transition references an unknown id, mixes disease-state ids
            with category ids, or moves between categories of more than one
            stratification. Id collections in messages are sorted so the text
            is deterministic.
        """
        disease_state_ids = {state.id for state in self.population.disease_states}
        # Built once rather than per transition: maps each category to the id
        # of the stratification that declares it.
        category_to_stratification_map = {
            cat: strat.id
            for strat in self.population.stratifications
            for cat in strat.categories
        }
        categories_ids = set(category_to_stratification_map)
        disease_state_and_categories_ids = disease_state_ids | categories_ids

        for transition in self.dynamics.transitions:
            transition_ids = set(transition.source) | set(transition.target)

            invalid_ids = transition_ids - disease_state_and_categories_ids
            if invalid_ids:
                raise ValueError(
                    (
                        f"Transition '{transition.id}' contains invalid ids: "
                        f"{sorted(invalid_ids)}. Ids must be defined in DiseaseState "
                        f"ids or Stratification Categories."
                    )
                )

            is_disease_state_flow = transition_ids.issubset(disease_state_ids)
            is_stratification_flow = transition_ids.issubset(categories_ids)

            if not (is_disease_state_flow or is_stratification_flow):
                disease_state_elements = sorted(transition_ids & disease_state_ids)
                categories_elements = sorted(transition_ids & categories_ids)
                raise ValueError(
                    (
                        f"Transition '{transition.id}' mixes id types. "
                        f"Found DiseaseState ids ({disease_state_elements}) and "
                        f"Stratification Categories ids ({categories_elements}). "
                        "Transitions must be purely Disease State flow or purely "
                        f"Stratification flow."
                    )
                )

            if is_stratification_flow:
                parent_stratification_ids = {
                    category_to_stratification_map[cat_id] for cat_id in transition_ids
                }
                if len(parent_stratification_ids) > 1:
                    mixed_strats = ", ".join(sorted(parent_stratification_ids))
                    raise ValueError(
                        (
                            f"Transition '{transition.id}' is a Stratification flow "
                            f"but involves categories from multiple stratifications: "
                            f"{mixed_strats}. A single transition must only move "
                            f"between categories belonging to the same parent "
                            f"stratification."
                        )
                    )

        return self

    def print_equations(self, output_file: str | None = None) -> None:
        """
        Prints the difference equations of the model in mathematical form.

        Displays model metadata and the system of difference equations. Stratified
        models get both a compact (mathematical notation) and an expanded
        (individual equations) form; unstratified models get only the equations.

        Parameters
        ----------
        output_file : str | None
            If provided, writes the equations to this file path instead of printing
            to console. If None, prints to console.

        Raises
        ------
        ValueError
            If the model is not a DifferenceEquations model.
        """
        if self.dynamics.typology != ModelTypes.DIFFERENCE_EQUATIONS:
            raise ValueError(
                (
                    f"print_equations only supports DifferenceEquations models. "
                    f"Current model type: {self.dynamics.typology}"
                )
            )

        lines = self._generate_model_header()

        # The compact form is only meaningful when there are stratifications.
        if self.population.stratifications:
            lines.extend(self._generate_compact_form())
            lines.append("")
        lines.extend(self._generate_expanded_form())

        self._write_output("\n".join(lines), output_file)

    def _generate_model_header(self) -> list[str]:
        """Generate the header lines with model metadata."""
        lines: list[str] = [
            "=" * 40,
            "MODEL INFORMATION",
            "=" * 40,
            f"Model: {self.name}",
            f"Model Type: {self.dynamics.typology}",
            f"Number of Disease States: {len(self.population.disease_states)}",
            f"Number of Stratifications: {len(self.population.stratifications)}",
            f"Number of Parameters: {len(self.parameters)}",
            f"Number of Transitions: {len(self.dynamics.transitions)}",
        ]

        disease_state_ids = [state.id for state in self.population.disease_states]
        lines.append(f"Disease States: {', '.join(disease_state_ids)}")

        if self.population.stratifications:
            lines.append("Stratifications:")
            for strat in self.population.stratifications:
                categories = ", ".join(strat.categories)
                lines.append(f"  - {strat.id}: [{categories}]")

        lines.append("")
        return lines

    def _collect_state_ids(self) -> set[str]:
        """Collect all state IDs from disease states and stratifications."""
        state_ids = {state.id for state in self.population.disease_states}
        for strat in self.population.stratifications:
            state_ids.update(strat.categories)
        return state_ids

    def _build_flow_equations(
        self, state_ids: set[str]
    ) -> dict[str, dict[str, list[str]]]:
        """
        Build a mapping of states to their inflows and outflows.

        For each transition, a state's net change is (occurrences in target)
        minus (occurrences in source). A positive net change records the
        transition's rate as an inflow, a negative one as an outflow.
        Transitions without a rate contribute an empty string.
        """
        equations: dict[str, dict[str, list[str]]] = {
            state_id: {"inflows": [], "outflows": []} for state_id in state_ids
        }
        for transition in self.dynamics.transitions:
            rate = transition.rate or ""
            for state in set(transition.source) | set(transition.target):
                net_change = transition.target.count(state) - transition.source.count(
                    state
                )
                if net_change > 0:
                    equations[state]["inflows"].append(rate)
                elif net_change < 0:
                    equations[state]["outflows"].append(rate)

        return equations

    @staticmethod
    def _join_terms(terms: list[str]) -> str:
        """
        Join signed equation terms such as ``"+ (a)"`` / ``"- (b)"``.

        Returns an empty string when ``terms`` is empty. Only a leading ``+``
        character is removed; the space after it is kept, so such results start
        with a space (existing output format, preserved as-is).
        """
        if not terms:
            return ""
        result = " ".join(terms)
        if result.startswith("+"):
            result = result[1:]
        return result

    def _format_state_equation(self, flows: dict[str, list[str]]) -> str:
        """Format the equation for a single state from its flows (inflows first)."""
        terms = [f"+ ({inflow})" for inflow in flows["inflows"]]
        terms.extend(f"- ({outflow})" for outflow in flows["outflows"])
        return self._join_terms(terms)

    def _generate_compact_form(self) -> list[str]:
        """Generate compact mathematical notation form for stratified models."""
        lines: list[str] = ["=" * 40, "COMPACT FORM", "=" * 40, ""]

        disease_state_ids = [state.id for state in self.population.disease_states]
        disease_transitions, stratification_transitions = (
            self._separate_transitions_by_type()
        )

        lines.extend(
            self._format_stratification_transitions_compact(
                disease_state_ids, stratification_transitions
            )
        )
        lines.extend(
            self._format_disease_transitions_compact(
                disease_state_ids, disease_transitions
            )
        )
        lines.extend(self._format_total_system_size(disease_state_ids))

        return lines

    def _separate_transitions_by_type(
        self,
    ) -> tuple[list[Transition], list[Transition]]:
        """
        Separate transitions into disease and stratification types.

        A transition whose source and target ids are all disease states is a
        disease transition; every other transition is a stratification one.
        """
        disease_state_set = {state.id for state in self.population.disease_states}

        disease_transitions: list[Transition] = []
        stratification_transitions: list[Transition] = []

        for transition in self.dynamics.transitions:
            transition_states = set(transition.source) | set(transition.target)
            if transition_states.issubset(disease_state_set):
                disease_transitions.append(transition)
            else:
                stratification_transitions.append(transition)

        return disease_transitions, stratification_transitions

    def _format_stratification_transitions_compact(
        self, disease_state_ids: list[str], stratification_transitions: list[Transition]
    ) -> list[str]:
        """Format stratification transitions in compact form."""
        lines: list[str] = []
        strat_by_id = self._group_transitions_by_stratification(
            stratification_transitions
        )
        disease_states_str = ", ".join(disease_state_ids)

        for strat in self.population.stratifications:
            strat_transitions = strat_by_id[strat.id]
            if not strat_transitions:
                continue

            lines.append(f"Stratification Transitions ({strat.id}):")
            lines.append(f"For each disease state X in {{{disease_states_str}}}:")

            for category in strat.categories:
                equation = self._build_category_equation(category, strat_transitions)
                if equation:
                    lines.append(f"  dX_{category}/dt: {equation}")

            lines.append("")

        return lines

    def _group_transitions_by_stratification(
        self, transitions: list[Transition]
    ) -> dict[str, list[Transition]]:
        """
        Group stratification transitions by their stratification ID.

        A transition is listed under every stratification whose categories
        contain all of its source and target ids.
        """
        strat_by_id: dict[str, list[Transition]] = {}
        for strat in self.population.stratifications:
            strat_categories = set(strat.categories)
            strat_by_id[strat.id] = [
                transition
                for transition in transitions
                if (set(transition.source) | set(transition.target)).issubset(
                    strat_categories
                )
            ]
        return strat_by_id

    def _build_category_equation(
        self, category: str, transitions: list[Transition]
    ) -> str:
        """
        Build equation for a stratification category.

        Each transition with a rate contributes ``(rate * X)`` as an inflow or
        outflow depending on the sign of its net change for ``category``.
        """
        inflows: list[str] = []
        outflows: list[str] = []

        for transition in transitions:
            if not transition.rate:
                continue

            net_change = transition.target.count(category) - transition.source.count(
                category
            )
            if net_change > 0:
                inflows.append(f"+ ({transition.rate} * X)")
            elif net_change < 0:
                outflows.append(f"- ({transition.rate} * X)")

        return self._join_terms(inflows + outflows)

    def _format_disease_transitions_compact(
        self, disease_state_ids: list[str], disease_transitions: list[Transition]
    ) -> list[str]:
        """Format disease state transitions in compact form."""
        if not disease_transitions:
            return []

        lines: list[str] = ["Disease State Transitions:"]

        if self.population.stratifications:
            all_categories = [
                cat
                for strat in self.population.stratifications
                for cat in strat.categories
            ]
            categories_str = ", ".join(all_categories)
            lines.append(f"For each stratification s in {{{categories_str}}}:")

        suffix = "_s" if self.population.stratifications else ""
        for disease_state in disease_state_ids:
            equation = self._build_disease_state_equation(
                disease_state, disease_transitions
            )
            if equation:
                lines.append(f"  d{disease_state}{suffix}/dt: {equation}")

        lines.append("")
        return lines

    def _build_disease_state_equation(
        self, disease_state: str, transitions: list[Transition]
    ) -> str:
        """
        Build equation for a disease state.

        Each transition with a rate contributes ``(rate)`` as an inflow or
        outflow depending on the sign of its net change for ``disease_state``.
        """
        inflows: list[str] = []
        outflows: list[str] = []

        for transition in transitions:
            if not transition.rate:
                continue

            net_change = transition.target.count(
                disease_state
            ) - transition.source.count(disease_state)
            if net_change > 0:
                inflows.append(f"+ ({transition.rate})")
            elif net_change < 0:
                outflows.append(f"- ({transition.rate})")

        return self._join_terms(inflows + outflows)

    def _format_total_system_size(self, disease_state_ids: list[str]) -> list[str]:
        """
        Format the total system size information.

        The equation count is the number of disease states times the product
        of the category counts of all stratifications.
        """
        num_disease_states = len(disease_state_ids)
        num_strat_combinations = 1
        for strat in self.population.stratifications:
            num_strat_combinations *= len(strat.categories)
        total_equations = num_disease_states * num_strat_combinations

        return [
            (
                f"Total System: {total_equations} coupled equations "
                f"({num_disease_states} disease states × {num_strat_combinations} "
                f"stratification)"
            )
        ]

    def _generate_expanded_form(self) -> list[str]:
        """Generate expanded form with individual equations."""
        title = "EXPANDED FORM" if self.population.stratifications else "EQUATIONS"
        lines: list[str] = ["=" * 40, title, "=" * 40]

        equations = self._build_flow_equations(self._collect_state_ids())

        # Order: disease states first (in order),
        # then stratification categories (in order)
        disease_state_ids = [state.id for state in self.population.disease_states]
        stratification_category_ids = [
            cat for strat in self.population.stratifications for cat in strat.categories
        ]

        for state_id in disease_state_ids + stratification_category_ids:
            equation = self._format_state_equation(equations[state_id])
            lines.append(f"d{state_id}/dt = {equation}")

        return lines

    def _write_output(self, output: str, output_file: str | None) -> None:
        """
        Write output to file or console.

        The file is written as UTF-8 explicitly because the output can contain
        non-ASCII characters (e.g. ``×``), which the locale default encoding
        may not support.
        """
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                _ = f.write(output)
        else:
            print(output)

Functions

validate_unique_parameter_ids

validate_unique_parameter_ids() -> Self

Validates that parameter IDs are unique.

Source code in epimodel/context/model.py
@model_validator(mode="after")
def validate_unique_parameter_ids(self) -> Self:
    """
    Validates that parameter IDs are unique.

    Raises
    ------
    ValueError
        If two or more parameters share an ID. Duplicates are listed in
        order of first appearance so the message is deterministic.
    """
    from collections import Counter

    id_counts = Counter(p.id for p in self.parameters)
    duplicates = [pid for pid, count in id_counts.items() if count > 1]
    if duplicates:
        raise ValueError(f"Duplicate parameter IDs found: {duplicates}")
    return self

validate_formula_variables

validate_formula_variables() -> Self

Validate that all variables in rate expressions are defined. This is done by gathering all valid identifiers and checking each transition's rate expressions against them.

Source code in epimodel/context/model.py
@model_validator(mode="after")
def validate_formula_variables(self) -> Self:
    """
    Checks that every variable used in a transition rate is defined.

    The set of known identifiers is computed a single time; each transition's
    rate expressions are then checked against it, and the first failure
    raises ``ValueError``.
    """
    known_names = self._get_valid_identifiers()
    for each_transition in self.dynamics.transitions:
        self._validate_transition_rates(each_transition, known_names)
    return self

validate_transition_ids

validate_transition_ids() -> Self

Validates that transition ids (source/target) are consistent in type and match the defined DiseaseState IDs or Stratification Categories in the Population instance.

Source code in epimodel/context/model.py
@model_validator(mode="after")
def validate_transition_ids(self) -> Self:
    """
    Validates that transition ids (source/target) are consistent in type
    and match the defined DiseaseState IDs or Stratification Categories
    in the Population instance.

    Raises
    ------
    ValueError
        If a transition references an unknown id, mixes disease-state ids
        with category ids, or moves between categories of more than one
        stratification. Id collections in messages are sorted so the text
        is deterministic.
    """
    disease_state_ids = {state.id for state in self.population.disease_states}
    # Built once rather than per transition: maps each category to the id
    # of the stratification that declares it.
    category_to_stratification_map = {
        cat: strat.id
        for strat in self.population.stratifications
        for cat in strat.categories
    }
    categories_ids = set(category_to_stratification_map)
    disease_state_and_categories_ids = disease_state_ids | categories_ids

    for transition in self.dynamics.transitions:
        transition_ids = set(transition.source) | set(transition.target)

        invalid_ids = transition_ids - disease_state_and_categories_ids
        if invalid_ids:
            raise ValueError(
                (
                    f"Transition '{transition.id}' contains invalid ids: "
                    f"{sorted(invalid_ids)}. Ids must be defined in DiseaseState "
                    f"ids or Stratification Categories."
                )
            )

        is_disease_state_flow = transition_ids.issubset(disease_state_ids)
        is_stratification_flow = transition_ids.issubset(categories_ids)

        if not (is_disease_state_flow or is_stratification_flow):
            disease_state_elements = sorted(transition_ids & disease_state_ids)
            categories_elements = sorted(transition_ids & categories_ids)
            raise ValueError(
                (
                    f"Transition '{transition.id}' mixes id types. "
                    f"Found DiseaseState ids ({disease_state_elements}) and "
                    f"Stratification Categories ids ({categories_elements}). "
                    "Transitions must be purely Disease State flow or purely "
                    f"Stratification flow."
                )
            )

        if is_stratification_flow:
            parent_stratification_ids = {
                category_to_stratification_map[cat_id] for cat_id in transition_ids
            }
            if len(parent_stratification_ids) > 1:
                mixed_strats = ", ".join(sorted(parent_stratification_ids))
                raise ValueError(
                    (
                        f"Transition '{transition.id}' is a Stratification flow "
                        f"but involves categories from multiple stratifications: "
                        f"{mixed_strats}. A single transition must only move "
                        f"between categories belonging to the same parent "
                        f"stratification."
                    )
                )

    return self

print_equations

print_equations(output_file: str | None = None) -> None

Prints the difference equations of the model in mathematical form.

Displays model metadata and the system of difference equations in both compact (mathematical notation) and expanded (individual equations) forms.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `output_file` | `str \| None` | If provided, writes the equations to this file path instead of printing to console. If None, prints to console. | `None` |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the model is not a DifferenceEquations model. |

Source code in epimodel/context/model.py
def print_equations(self, output_file: str | None = None) -> None:
    """
    Renders the model's difference equations as text.

    A metadata header is always emitted. Stratified models additionally get a
    compact (mathematical notation) section followed by a blank line; every
    model then gets the expanded (one equation per state) section.

    Parameters
    ----------
    output_file : str | None
        Destination file path. When falsy, the text is printed to the console.

    Raises
    ------
    ValueError
        If the model is not a DifferenceEquations model.
    """
    typology = self.dynamics.typology
    if typology != ModelTypes.DIFFERENCE_EQUATIONS:
        raise ValueError(
            f"print_equations only supports DifferenceEquations models. "
            f"Current model type: {typology}"
        )

    sections = self._generate_model_header()
    if len(self.population.stratifications) > 0:
        # Compact notation is only produced for stratified models.
        sections += self._generate_compact_form()
        sections += [""]
    sections += self._generate_expanded_form()

    self._write_output("\n".join(sections), output_file)

options: show_root_heading: true show_source: true heading_level: 3

Population

Population

Bases: BaseModel

Defines the compartments, stratifications, and initial conditions of the population.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `disease_states` | `list[DiseaseState]` | A list of compartments or states that make up the model. |
| `stratifications` | `list[Stratification]` | A list of categorical subdivisions of the population. |
| `initial_conditions` | `InitialConditions` | Initial state of the subpopulations and stratifications. |

Source code in epimodel/context/population.py
class Population(BaseModel):
    """
    Defines the compartments, stratifications, and initial conditions of the population.

    Attributes
    ----------
    disease_states : list[DiseaseState]
        A list of compartments or states that make up the model.
    stratifications : list[Stratification]
        A list of categorical subdivisions of the population.
    transitions : list[Transition]
        Transitions whose stratified rates are checked against the defined
        stratifications and their categories.
    initial_conditions : InitialConditions
        Initial state of the subpopulations and stratifications.
    """

    disease_states: list[DiseaseState]
    stratifications: list[Stratification]
    transitions: list[Transition]
    initial_conditions: InitialConditions

    @field_validator("disease_states")
    @classmethod
    def validate_disease_states_not_empty(
        cls, v: list[DiseaseState]
    ) -> list[DiseaseState]:
        """
        Enforces that at least one disease state is defined.
        """
        if not v:
            raise ValueError("At least one disease state must be defined.")
        return v

    @model_validator(mode="after")
    def validate_unique_ids(self) -> Self:
        """
        Validates that disease state and stratification IDs are unique.
        """
        from collections import Counter

        # Counter counts in O(n) and keeps first-occurrence order, so the
        # reported duplicates are deterministic (iterating a set is not).
        disease_state_counts = Counter(ds.id for ds in self.disease_states)
        duplicates = [item for item, n in disease_state_counts.items() if n > 1]
        if duplicates:
            raise ValueError(f"Duplicate disease state IDs found: {duplicates}")

        stratification_counts = Counter(s.id for s in self.stratifications)
        duplicates = [item for item, n in stratification_counts.items() if n > 1]
        if duplicates:
            raise ValueError(f"Duplicate stratification IDs found: {duplicates}")

        return self

    @model_validator(mode="after")
    def validate_disease_state_initial_conditions(self) -> Self:
        """
        Validates the initial disease state fractions against the defined
        disease states.

        Each disease state must appear exactly once in
        ``initial_conditions.disease_state_fractions`` and the fractions must
        sum to 1.0 (absolute tolerance 1e-6).
        """
        from collections import Counter

        initial_conditions = self.initial_conditions

        disease_states_map = {state.id: state for state in self.disease_states}

        # Reject repeated entries before building the dict below: the dict
        # comprehension would silently keep only the last fraction of a
        # repeated id, letting an invalid configuration pass the sum check.
        state_counts = Counter(
            dsf.disease_state for dsf in initial_conditions.disease_state_fractions
        )
        repeated_states = [item for item, n in state_counts.items() if n > 1]
        if repeated_states:
            raise ValueError(
                f"Duplicate disease state ids in initial disease state fractions: "
                f"{repeated_states}."
            )

        disease_state_fractions_dict = {
            dsf.disease_state: dsf.fraction
            for dsf in initial_conditions.disease_state_fractions
        }

        actual_state = set(disease_state_fractions_dict.keys())
        expected_state = set(disease_states_map.keys())

        if actual_state != expected_state:
            missing = expected_state - actual_state
            extra = actual_state - expected_state
            raise ValueError(
                (
                    f"Initial disease state fractions keys must exactly match "
                    f"disease state ids. Missing ids: {missing}, Extra ids: {extra}."
                )
            )

        states_sum_fractions = sum(disease_state_fractions_dict.values())
        if not math.isclose(states_sum_fractions, 1.0, abs_tol=1e-6):
            raise ValueError(
                (
                    f"Disease state fractions must sum to 1.0, "
                    f"but got {states_sum_fractions:.7f}."
                )
            )

        return self

    @model_validator(mode="after")
    def validate_stratified_rates(self) -> Self:
        """
        Validates that stratified rates reference existing stratifications and
        categories.
        """
        strat_map = {strat.id: strat for strat in self.stratifications}

        for transition in self.transitions:
            if transition.stratified_rates:
                for idx, stratified_rate in enumerate(transition.stratified_rates):
                    for condition in stratified_rate.conditions:
                        # Validate stratification exists
                        if condition.stratification not in strat_map:
                            raise ValueError(
                                (
                                    f"In transition '{transition.id}', stratified rate "
                                    f"{idx}: Stratification "
                                    f"'{condition.stratification}' not found. "
                                    f"Available: {list(strat_map.keys())}"
                                )
                            )

                        # Validate category exists
                        strat = strat_map[condition.stratification]
                        if condition.category not in strat.categories:
                            raise ValueError(
                                (
                                    f"In transition '{transition.id}', stratified rate "
                                    f"{idx}: Category '{condition.category}' not found "
                                    f"in stratification '{condition.stratification}'. "
                                    f"Available: {strat.categories}"
                                )
                            )

        return self

    @model_validator(mode="after")
    def validate_stratification_initial_conditions(self) -> Self:
        """
        Validates the initial stratification fractions against the defined
        stratifications.

        Every stratification must appear exactly once in
        ``initial_conditions.stratification_fractions``. Within each entry,
        every category must appear exactly once and the category fractions
        must sum to 1.0 (absolute tolerance 1e-6).
        """
        from collections import Counter

        initial_conditions = self.initial_conditions

        strat_map = {strat.id: strat for strat in self.stratifications}

        # A repeated stratification entry would be hidden by the set
        # comparison below and then validated twice, possibly with
        # conflicting category fractions.
        strat_counts = Counter(
            sf.stratification for sf in initial_conditions.stratification_fractions
        )
        repeated_strats = [item for item, n in strat_counts.items() if n > 1]
        if repeated_strats:
            raise ValueError(
                f"Duplicate stratification ids in initial stratification "
                f"fractions: {repeated_strats}."
            )

        actual_strat = set(strat_counts)
        expected_strat = set(strat_map.keys())

        if actual_strat != expected_strat:
            missing = expected_strat - actual_strat
            extra = actual_strat - expected_strat
            raise ValueError(
                (
                    f"Initial stratification fractions keys must exactly match "
                    f"stratification ids. Missing ids: {missing}, Extra ids: {extra}."
                )
            )

        for strat_fractions in initial_conditions.stratification_fractions:
            strat_id = strat_fractions.stratification
            strat_instance = strat_map[strat_id]

            # Check for repeats before the dict comprehension collapses them.
            category_counts = Counter(sf.category for sf in strat_fractions.fractions)
            repeated_categories = [
                item for item, n in category_counts.items() if n > 1
            ]
            if repeated_categories:
                raise ValueError(
                    f"Duplicate categories in initial fractions for stratification "
                    f"'{strat_id}': {repeated_categories}."
                )

            fractions_dict = {
                sf.category: sf.fraction for sf in strat_fractions.fractions
            }

            categories_expected = set(strat_instance.categories)
            categories_actual = set(fractions_dict.keys())

            if categories_actual != categories_expected:
                missing = categories_expected - categories_actual
                extra = categories_actual - categories_expected
                raise ValueError(
                    (
                        f"Categories for stratification '{strat_id}' must exactly "
                        f"match defined categories in instance '{strat_instance.id}'. "
                        f"Missing categories: {missing}, Extra categories: {extra}."
                    )
                )

            strat_sum_fractions = sum(fractions_dict.values())
            if not math.isclose(strat_sum_fractions, 1.0, abs_tol=1e-6):
                # Fixed-point format, matching the disease state message.
                raise ValueError(
                    (
                        f"Stratification fractions for '{strat_id}' must sum to 1.0, "
                        f"but got {strat_sum_fractions:.7f}."
                    )
                )

        return self

Functions

validate_unique_ids

validate_unique_ids() -> Self

Validates that disease state and stratification IDs are unique.

Source code in epimodel/context/population.py
@model_validator(mode="after")
def validate_unique_ids(self) -> Self:
    """
    Validates that disease state and stratification IDs are unique.
    """
    from collections import Counter

    # Counter counts in O(n) and keeps first-occurrence order, so the
    # reported duplicates are deterministic (iterating a set is not).
    disease_state_counts = Counter(ds.id for ds in self.disease_states)
    duplicates = [item for item, n in disease_state_counts.items() if n > 1]
    if duplicates:
        raise ValueError(f"Duplicate disease state IDs found: {duplicates}")

    stratification_counts = Counter(s.id for s in self.stratifications)
    duplicates = [item for item, n in stratification_counts.items() if n > 1]
    if duplicates:
        raise ValueError(f"Duplicate stratification IDs found: {duplicates}")

    return self

validate_disease_state_initial_conditions

validate_disease_state_initial_conditions() -> Self

Validates the initial disease state fractions against the defined disease states.

Source code in epimodel/context/population.py
@model_validator(mode="after")
def validate_disease_state_initial_conditions(self) -> Self:
    """
    Validates the initial disease state fractions against the defined
    disease states.

    Each disease state must appear exactly once in
    ``initial_conditions.disease_state_fractions`` and the fractions must
    sum to 1.0 (absolute tolerance 1e-6).
    """
    from collections import Counter

    initial_conditions = self.initial_conditions

    disease_states_map = {state.id: state for state in self.disease_states}

    # Reject repeated entries before building the dict below: the dict
    # comprehension would silently keep only the last fraction of a
    # repeated id, letting an invalid configuration pass the sum check.
    state_counts = Counter(
        dsf.disease_state for dsf in initial_conditions.disease_state_fractions
    )
    repeated_states = [item for item, n in state_counts.items() if n > 1]
    if repeated_states:
        raise ValueError(
            f"Duplicate disease state ids in initial disease state fractions: "
            f"{repeated_states}."
        )

    disease_state_fractions_dict = {
        dsf.disease_state: dsf.fraction
        for dsf in initial_conditions.disease_state_fractions
    }

    actual_state = set(disease_state_fractions_dict.keys())
    expected_state = set(disease_states_map.keys())

    if actual_state != expected_state:
        missing = expected_state - actual_state
        extra = actual_state - expected_state
        raise ValueError(
            (
                f"Initial disease state fractions keys must exactly match "
                f"disease state ids. Missing ids: {missing}, Extra ids: {extra}."
            )
        )

    states_sum_fractions = sum(disease_state_fractions_dict.values())
    if not math.isclose(states_sum_fractions, 1.0, abs_tol=1e-6):
        raise ValueError(
            (
                f"Disease state fractions must sum to 1.0, "
                f"but got {states_sum_fractions:.7f}."
            )
        )

    return self

validate_stratified_rates

validate_stratified_rates() -> Self

Validates that stratified rates reference existing stratifications and categories.

Source code in epimodel/context/population.py
@model_validator(mode="after")
def validate_stratified_rates(self) -> Self:
    """
    Checks every stratified rate condition of every transition.

    Each condition must name a stratification defined on this population,
    and its category must be one of that stratification's categories. The
    first offending condition raises a ValueError that identifies the
    transition id and the index of the stratified rate.
    """
    strat_by_id = {s.id: s for s in self.stratifications}

    for transition in self.transitions:
        # Transitions without stratified rates (None or empty) contribute nothing.
        for position, rate in enumerate(transition.stratified_rates or []):
            for cond in rate.conditions:
                matched = strat_by_id.get(cond.stratification)
                if matched is None:
                    raise ValueError(
                        f"In transition '{transition.id}', stratified rate "
                        f"{position}: Stratification "
                        f"'{cond.stratification}' not found. "
                        f"Available: {list(strat_by_id)}"
                    )
                if cond.category not in matched.categories:
                    raise ValueError(
                        f"In transition '{transition.id}', stratified rate "
                        f"{position}: Category '{cond.category}' not found "
                        f"in stratification '{cond.stratification}'. "
                        f"Available: {matched.categories}"
                    )

    return self

validate_stratification_initial_conditions

validate_stratification_initial_conditions() -> Self

Validates the initial stratification fractions against the defined stratifications.

Source code in epimodel/context/population.py
@model_validator(mode="after")
def validate_stratification_initial_conditions(self) -> Self:
    """
    Validates the initial stratification fractions against the defined
    stratifications.

    Every stratification must appear exactly once in
    ``initial_conditions.stratification_fractions``. Within each entry,
    every category must appear exactly once and the category fractions
    must sum to 1.0 (absolute tolerance 1e-6).
    """
    from collections import Counter

    initial_conditions = self.initial_conditions

    strat_map = {strat.id: strat for strat in self.stratifications}

    # A repeated stratification entry would be hidden by the set
    # comparison below and then validated twice, possibly with
    # conflicting category fractions.
    strat_counts = Counter(
        sf.stratification for sf in initial_conditions.stratification_fractions
    )
    repeated_strats = [item for item, n in strat_counts.items() if n > 1]
    if repeated_strats:
        raise ValueError(
            f"Duplicate stratification ids in initial stratification "
            f"fractions: {repeated_strats}."
        )

    actual_strat = set(strat_counts)
    expected_strat = set(strat_map.keys())

    if actual_strat != expected_strat:
        missing = expected_strat - actual_strat
        extra = actual_strat - expected_strat
        raise ValueError(
            (
                f"Initial stratification fractions keys must exactly match "
                f"stratification ids. Missing ids: {missing}, Extra ids: {extra}."
            )
        )

    for strat_fractions in initial_conditions.stratification_fractions:
        strat_id = strat_fractions.stratification
        strat_instance = strat_map[strat_id]

        # Check for repeats before the dict comprehension collapses them.
        category_counts = Counter(sf.category for sf in strat_fractions.fractions)
        repeated_categories = [item for item, n in category_counts.items() if n > 1]
        if repeated_categories:
            raise ValueError(
                f"Duplicate categories in initial fractions for stratification "
                f"'{strat_id}': {repeated_categories}."
            )

        fractions_dict = {
            sf.category: sf.fraction for sf in strat_fractions.fractions
        }

        categories_expected = set(strat_instance.categories)
        categories_actual = set(fractions_dict.keys())

        if categories_actual != categories_expected:
            missing = categories_expected - categories_actual
            extra = categories_actual - categories_expected
            raise ValueError(
                (
                    f"Categories for stratification '{strat_id}' must exactly "
                    f"match defined categories in instance '{strat_instance.id}'. "
                    f"Missing categories: {missing}, Extra categories: {extra}."
                )
            )

        strat_sum_fractions = sum(fractions_dict.values())
        if not math.isclose(strat_sum_fractions, 1.0, abs_tol=1e-6):
            # Fixed-point format, matching the disease state message.
            raise ValueError(
                (
                    f"Stratification fractions for '{strat_id}' must sum to 1.0, "
                    f"but got {strat_sum_fractions:.7f}."
                )
            )

    return self

options: show_root_heading: true show_source: true heading_level: 3

Disease States

DiseaseState

Bases: BaseModel

Defines a single disease state (compartment) that a person can occupy with respect to the disease.

Attributes:

Name Type Description
id str

Identifier of the disease state.

name str

A descriptive, human-readable name for the disease state.

Source code in epimodel/context/disease_state.py
class DiseaseState(BaseModel):
    """
    A single disease state (compartment) of the model.

    Attributes
    ----------
    id : str
        Identifier of the disease state.
    name : str
        A descriptive, human-readable name for the disease state.

    Notes
    -----
    Equality and hashing depend on ``id`` alone, so two instances with the
    same id but different names compare equal and hash identically.
    """

    id: str = Field(..., description="Identifier of the disease state.")
    name: str = Field(
        ..., description="Descriptive, human-readable name for the disease state."
    )

    @override
    def __hash__(self):
        # Hash only the identifier so hashing stays consistent with __eq__.
        return hash(self.id)

    @override
    def __eq__(self, other: object):
        if not isinstance(other, DiseaseState):
            return False
        return other.id == self.id

options: show_root_heading: true show_source: true heading_level: 3

Stratifications

Stratification

Bases: BaseModel

Defines a categorical subdivision of the population.

Attributes:

Name Type Description
id str

Identifier of the stratification.

categories list[str]

List of the identifiers of the categories (groups) within the stratification.

Source code in epimodel/context/stratification.py
class Stratification(BaseModel):
    """
    Defines a categorical subdivision of the population.

    Attributes
    ----------
    id : str
        Identifier of the stratification.
    categories : list[str]
        Identifiers of the categories (groups) in this stratification. Must be
        non-empty and free of repeats.

    Notes
    -----
    Equality and hashing depend on ``id`` alone.
    """

    id: str = Field(..., description="Identifier of the stratification.")
    categories: list[str] = Field(
        ..., description="List of the different stratification groups identifiers."
    )

    @override
    def __hash__(self):
        return hash(self.id)

    @override
    def __eq__(self, other: object):
        return isinstance(other, Stratification) and self.id == other.id

    @model_validator(mode="after")
    def validate_categories_length(self) -> Self:
        """
        Enforces that categories are not empty.
        """
        if not self.categories:
            raise ValueError(
                (f"Stratification '{self.id}' must have at least one category.")
            )
        return self

    @model_validator(mode="after")
    def validate_categories_uniqueness(self) -> Self:
        """
        Enforces that categories are not repeated.
        """
        from collections import Counter

        # Counter counts in O(n) and lists duplicates in first-occurrence
        # order, making the error message deterministic.
        duplicates = [
            item for item, count in Counter(self.categories).items() if count > 1
        ]
        if duplicates:
            raise ValueError(
                (
                    f"Categories for stratification '{self.id}' must not be repeated. "
                    f"Found duplicates: {duplicates}."
                )
            )

        return self

Functions

validate_categories_length

validate_categories_length() -> Self

Enforces that categories are not empty.

Source code in epimodel/context/stratification.py
@model_validator(mode="after")
def validate_categories_length(self) -> Self:
    """
    Rejects a stratification whose category list is empty.

    Returns the model unchanged when at least one category is present;
    otherwise raises ValueError.
    """
    if self.categories:
        return self
    raise ValueError(f"Stratification '{self.id}' must have at least one category.")

validate_categories_uniqueness

validate_categories_uniqueness() -> Self

Enforces that categories are not repeated.

Source code in epimodel/context/stratification.py
@model_validator(mode="after")
def validate_categories_uniqueness(self) -> Self:
    """
    Enforces that categories are not repeated.
    """
    from collections import Counter

    # Counter counts in O(n) and lists duplicates in first-occurrence
    # order, making the error message deterministic.
    duplicates = [
        item for item, count in Counter(self.categories).items() if count > 1
    ]
    if duplicates:
        raise ValueError(
            (
                f"Categories for stratification '{self.id}' must not be repeated. "
                f"Found duplicates: {duplicates}."
            )
        )

    return self

options: show_root_heading: true show_source: true heading_level: 3

Parameters

Parameter

Bases: BaseModel

Defines a global model parameter.

Attributes:

Name Type Description
id str

The identifier of the parameter.

value float

Numerical value of the parameter.

description str | None

A human-readable description of the parameter.

Source code in epimodel/context/parameter.py
class Parameter(BaseModel):
    """
    A named global numeric parameter of the model.

    Attributes
    ----------
    id : str
        The identifier of the parameter.
    value : float
        Numerical value of the parameter.
    description : str | None
        Optional human-readable description of the parameter (None by default).
    """

    id: str = Field(default=..., description="Identifier of the parameter.")
    value: float = Field(default=..., description="Numerical value of the parameter.")
    description: str | None = Field(
        default=None,
        description="Human-readable description of the parameter.",
    )

options: show_root_heading: true show_source: true heading_level: 3

Transitions

Transition

Bases: BaseModel

Defines a rule for system evolution.

Attributes:

Name Type Description
id str

Id of the transition.

source list[str]

The origin compartments.

target list[str]

The destination compartments.

rate str | None

Default mathematical formula, parameter name, or constant value for the flow. Used when no stratified rate matches. Numeric values are automatically converted to strings during validation.

- Operators: +, -, *, /, % (modulo), ^ or ** (power)
- Functions: sin, cos, tan, exp, ln, sqrt, abs, min, max, if, etc.
- Constants: pi, e

Note: Both ^ and ** are supported for exponentiation (** is converted to ^).

Examples:

- "beta" (parameter reference)
- "0.5" (constant, can also be passed as float 0.5)
- "beta * S * I / N" (mathematical formula)
- "0.3 * sin(2 * pi * t / 365)" (time-dependent formula)
- "2^10" or "2**10" (power: both syntaxes work)

stratified_rates list[StratifiedRate] | None

Stratification-specific rates. Each rate applies to compartments that match all specified stratification conditions.

condition Condition | None

Logical restrictions for the transition.

Source code in epimodel/context/dynamics.py
class Transition(BaseModel):
    """
    Defines a rule for system evolution.

    Attributes
    ----------
    id : str
        Id of the transition.
    source : list[str]
        The origin compartments.
    target : list[str]
        The destination compartments.
    rate : str | None
        Default mathematical formula, parameter name, or constant value for the flow.
        Used when no stratified rate matches. Numeric values are automatically
        converted to strings during validation.

        Operators: +, -, *, /, % (modulo), ^ or ** (power)
        Functions: sin, cos, tan, exp, ln, sqrt, abs, min, max, if, etc.
        Constants: pi, e

        Note: Both ^ and ** are supported for exponentiation (** is converted to ^).

        Examples:
        - "beta" (parameter reference)
        - "0.5" (constant, can also be passed as float 0.5)
        - "beta * S * I / N" (mathematical formula)
        - "0.3 * sin(2 * pi * t / 365)" (time-dependent formula)
        - "2^10" or "2**10" (power: both syntaxes work)
    stratified_rates : list[StratifiedRate] | None
        Stratification-specific rates. Each rate applies to compartments that match
        all specified stratification conditions.
    condition : Condition | None
        Logical restrictions for the transition.
    """

    id: str = Field(..., description="Id of the transition.")
    source: list[str] = Field(..., description="Origin compartments.")
    target: list[str] = Field(..., description="Destination compartments.")

    rate: str | None = Field(
        None,
        description=(
            "Default rate expression (fallback when no stratified rate matches). "
            "Can be a parameter reference (e.g., 'beta'), a constant (e.g., '0.5'), "
            "or a mathematical expression (e.g., 'beta * S * I / N'). "
            "Numeric values are automatically converted to strings during validation."
        ),
    )

    stratified_rates: list[StratifiedRate] | None = Field(
        None, description="List of stratification-specific rates"
    )

    condition: Condition | None = Field(
        None, description="Logical restrictions for the transition."
    )

    @field_validator("rate", mode="before")
    @classmethod
    def validate_rate(cls, value: str | int | float | None) -> str | None:
        """
        Convert numeric rates to strings and perform security and syntax validation.

        Parameters
        ----------
        value : str | int | float | None
            Raw rate as supplied by the caller. None is passed through unchanged.

        Returns
        -------
        str | None
            The validated rate expression, or None.

        Raises
        ------
        ValueError
            If ``validate_expression_security`` rejects the expression, or (when
            ``epimodel_rs`` is available) the Rust expression validator does.
        """
        if value is None:
            return value
        # This runs in "before" mode, so pydantic has not coerced anything yet,
        # and pydantic v2 does not coerce numbers to str by default. Convert here
        # so the documented `rate=0.5` form works. bool is an int subclass and is
        # deliberately left alone.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            value = str(value)
        try:
            validate_expression_security(value)
            if epimodel_rs:
                epimodel_rs.core.MathExpression(value).py_validate()
        except ValueError as e:
            # Chain the original error so its traceback is preserved.
            raise ValueError(f"Validation failed for rate '{value}': {e}") from e
        return value

Functions

validate_rate classmethod

validate_rate(value: str | None) -> str | None

Convert numeric rates to strings and perform security and syntax validation.

Source code in epimodel/context/dynamics.py
@field_validator("rate", mode="before")
@classmethod
def validate_rate(cls, value: str | int | float | None) -> str | None:
    """
    Convert numeric rates to strings and perform security and syntax validation.

    Parameters
    ----------
    value : str | int | float | None
        Raw rate as supplied by the caller. None is passed through unchanged.

    Returns
    -------
    str | None
        The validated rate expression, or None.

    Raises
    ------
    ValueError
        If ``validate_expression_security`` rejects the expression, or (when
        ``epimodel_rs`` is available) the Rust expression validator does.
    """
    if value is None:
        return value
    # This runs in "before" mode, so pydantic has not coerced anything yet,
    # and pydantic v2 does not coerce numbers to str by default. Convert here
    # so the documented `rate=0.5` form works. bool is an int subclass and is
    # deliberately left alone.
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        value = str(value)
    try:
        validate_expression_security(value)
        if epimodel_rs:
            epimodel_rs.core.MathExpression(value).py_validate()
    except ValueError as e:
        # Chain the original error so its traceback is preserved.
        raise ValueError(f"Validation failed for rate '{value}': {e}") from e
    return value

options: show_root_heading: true show_source: true heading_level: 3

Initial Conditions

InitialConditions

Bases: BaseModel

Initial conditions for a simulation.

Attributes:

Name Type Description
population_size int

Population size.

disease_state_fractions list[DiseaseStateFraction]

List of disease state fractions. Each item contains a disease state id and its initial fractional size.

stratification_fractions (list[StratificationFractions], optional)

List of stratification fractions. Each item contains a stratification id and its category fractions.

Source code in epimodel/context/initial_conditions.py
class InitialConditions(BaseModel):
    """
    Starting state used for a simulation.

    Attributes
    ----------
    population_size : int
        Population size.
    disease_state_fractions : list[DiseaseStateFraction]
        Items pairing a disease state id with its initial fractional size.
    stratification_fractions : list[StratificationFractions], optional
        Items pairing a stratification id with its category fractions. Defaults
        to an empty list.
    """

    population_size: int = Field(default=..., description="Population size.")

    disease_state_fractions: list[DiseaseStateFraction] = Field(
        default=...,
        description=(
            "List of disease state fractions. Each item contains a disease state id "
            "and its initial fractional size."
        ),
    )

    # A fresh list per instance; no stratifications unless provided.
    stratification_fractions: list[StratificationFractions] = Field(
        default_factory=list,
        description=(
            "List of stratification fractions. Each item contains a stratification id "
            "and its category fractions."
        ),
    )

options: show_root_heading: true show_source: true heading_level: 3

Dynamics

Dynamics

Bases: BaseModel

Defines how the system evolves.

Attributes:

Name Type Description
typology Literal['DifferenceEquations']

The type of model.

transitions list[Transition]

A list of rules for state changes.

Source code in epimodel/context/dynamics.py
class Dynamics(BaseModel):
    """
    Defines how the system evolves.

    Attributes
    ----------
    typology : Literal["DifferenceEquations"]
        The type of model.
    transitions : list[Transition]
        A non-empty list of rules for state changes; transition ids must be unique.
    """

    typology: Literal[ModelTypes.DIFFERENCE_EQUATIONS]
    transitions: list[Transition]

    @field_validator("transitions")
    @classmethod
    def validate_transitions_not_empty(cls, v: list[Transition]) -> list[Transition]:
        """
        Enforces that at least one transition is defined.
        """
        if not v:
            raise ValueError("At least one transition must be defined.")
        return v

    @model_validator(mode="after")
    def validate_unique_transition_ids(self) -> Self:
        """
        Validates that transition IDs are unique.
        """
        from collections import Counter

        # Counter counts in O(n) and lists duplicates in first-occurrence
        # order, unlike list.count over a set (O(n^2), arbitrary order).
        id_counts = Counter(t.id for t in self.transitions)
        duplicates = [item for item, n in id_counts.items() if n > 1]
        if duplicates:
            raise ValueError(f"Duplicate transition IDs found: {duplicates}")
        return self

Functions

validate_unique_transition_ids

validate_unique_transition_ids() -> Self

Validates that transition IDs are unique.

Source code in epimodel/context/dynamics.py
@model_validator(mode="after")
def validate_unique_transition_ids(self) -> Self:
    """
    Validates that transition IDs are unique.
    """
    from collections import Counter

    # Counter counts in O(n) and lists duplicates in first-occurrence
    # order, unlike list.count over a set (O(n^2), arbitrary order).
    id_counts = Counter(t.id for t in self.transitions)
    duplicates = [item for item, n in id_counts.items() if n > 1]
    if duplicates:
        raise ValueError(f"Duplicate transition IDs found: {duplicates}")
    return self

options: show_root_heading: true show_source: true heading_level: 3