-
-
Save Hammer2900/1e8a47e4c5aa84732d6537a382db3ca6 to your computer and use it in GitHub Desktop.
Полный пример реализации и использования Entity Component System на python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import Any, Type | |
EntityId = str | |
Component = object | |
@dataclass | |
class StoredSystem: | |
variables: dict[str, Any] | |
components: dict[str, Type[Component]] # key is argument name | |
has_entity_id_argument: bool | |
has_ecs_argument: bool |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from typing import Callable, Type, Any, Iterator | |
from ecs_types import EntityId, Component, StoredSystem | |
from unique_id import UniqueIdGenerator | |
class EntityComponentSystem: | |
def __init__(self, on_create: Callable[[EntityId, list[Component]], None] = None, | |
on_remove: Callable[[EntityId], None] = None): | |
""" | |
:param on_create: | |
Хук, отрабатывающий при создании сущности, | |
например может пригодиться, если сервер сообщает клиентам о появлении новых сущностей | |
:param on_remove: | |
Хук, отрабатывающий перед удалением сущности | |
""" | |
# Здесь хранятся все системы вместе с полученными от них сигнатурами | |
self._systems: dict[Callable, StoredSystem] = {} | |
# По типу компонента хранятся словари, содержащие сами компоненты по ключам entity_id | |
self._components: dict[Type[Component], dict[EntityId, Component]] = {} | |
self._entities: list[EntityId] = [] | |
self._vars = {} | |
self.on_create = on_create | |
self.on_remove = on_remove | |
def _unsafe_get_component(self, entity_id: EntityId, component_class: Type[Component]) -> Component: | |
""" | |
Возвращает компонент сущности с типом переданного класса component_class | |
Кидает KeyError если сущность не существует или не имеет такого компонента | |
""" | |
return self._components[component_class][entity_id] | |
def init_component(self, component_class: Type[Component]) -> None: | |
""" | |
Инициализация класса компонента. Следует вызвать до создания сущностей | |
""" | |
self._components[component_class] = {} | |
def add_variable(self, variable_name: str, variable_value: Any) -> None: | |
""" | |
Инициализация переменной. Далее может быть запрошена любой системой. | |
""" | |
self._vars[variable_name] = variable_value | |
def init_system(self, system: Callable): | |
""" | |
Инициализация системы. Если система зависит от внешней переменной - передайте её в add_variable до инициализации. | |
Внешние переменные и специальные аргументы (ecs: EntityComponentSystem и entity_id: EntityId) запрашиваются | |
через указание имени аргумента в функции системы. | |
Запрашиваемые компоненты указываются через указание типа аргумента (например dummy_health: HealthComponent). | |
Название аргумента в таком случае может быть названо как угодно. | |
Запрашиваемый компонент должен быть инициализирован до инициализации системы | |
""" | |
stored_system = StoredSystem( | |
components={}, | |
variables={}, | |
has_entity_id_argument=False, | |
has_ecs_argument=False | |
) | |
# Через сигнатуру функции системы узнаем какие данные и компоненты она запрашивает. | |
# Сохраним в StoredSystem чтобы не перепроверять сигнатуру каждый кадр. | |
system_params = inspect.signature(system).parameters | |
for param_name, param in system_params.items(): | |
if param_name == 'entity_id': # Система может требовать конкретный entity_id для переданных компонентов | |
stored_system.has_entity_id_argument = True | |
elif param_name == 'ecs': # Системе может потребоваться ссылка на ecs. Например, для удаления сущностей | |
stored_system.has_ecs_argument = True | |
elif param.annotation in self._components: | |
stored_system.components[param_name] = param.annotation | |
elif param_name in self._vars: | |
stored_system.variables[param_name] = self._vars[param_name] | |
else: | |
raise Exception(f'Wrong argument: {param_name}') | |
self._systems[system] = stored_system | |
def create_entity(self, components: list[Component], entity_id=None) -> EntityId: | |
""" | |
Создание сущности на основе списка его компонентов | |
Можно задавать свой entity_id но он обязан быть уникальным | |
""" | |
if entity_id is None: | |
entity_id = UniqueIdGenerator.generate_id() | |
else: | |
assert entity_id not in self._entities, f"Entity with id {entity_id} already exists" | |
for component in components: | |
self._components[component.__class__][entity_id] = component | |
self._entities.append(entity_id) | |
if self.on_create: | |
self.on_create(entity_id, components) | |
return entity_id | |
def get_entity_ids_with_components(self, *component_classes) -> set[EntityId]: | |
""" | |
Получить все entity_id у которых есть каждый из компонентов, указанных в component_classes | |
""" | |
if not component_classes: | |
return set(self._entities) | |
# Если запрошено несколько компонентов - то следует вернуть сущности, обладающие каждым из них | |
# Это достигается пересечением множеств entity_id по классу компонента | |
entities = set.intersection(*[set(self._components[component_class].keys()) | |
for component_class in component_classes]) | |
return entities | |
def get_entities_with_components(self, *component_classes) -> Iterator[tuple[EntityId, list[Component]]]: | |
""" | |
Получить все entity_id вместе с указанными компонентами | |
""" | |
for entity_id in self.get_entity_ids_with_components(*component_classes): | |
components = tuple(self._unsafe_get_component(entity_id, component_class) | |
for component_class in component_classes) | |
yield entity_id, components | |
def update(self) -> None: | |
""" | |
Вызывает все системы. | |
Следует вызывать в игровом цикле. | |
""" | |
for system_function, system in self._systems.items(): | |
for entity_id in self.get_entity_ids_with_components(*system.components.values()): | |
special_args = {} | |
if system.has_ecs_argument: | |
special_args['ecs'] = self | |
if system.has_entity_id_argument: | |
special_args['entity_id'] = entity_id | |
# Сделано для того чтобы в системе можно было указывать любые имена для запрашиваемых компонентов | |
required_components_arguments = {param: self._unsafe_get_component(entity_id, component_name) for | |
param, component_name in | |
system.components.items()} | |
system_function(**(required_components_arguments | system.variables | special_args)) | |
def remove_entity(self, entity_id: EntityId): | |
""" | |
Удаляет сущность | |
""" | |
if self.on_remove is not None: | |
self.on_remove(entity_id) | |
for components in self._components.values(): | |
components.pop(entity_id, None) | |
self._entities.remove(entity_id) | |
def get_component(self, entity_id: EntityId, component_class: Type[Component]): | |
""" | |
:return | |
Возвращает компонент сущности с типом переданного класса component_class | |
Возвращает None если сущность не существует или не имеет такого компонента | |
""" | |
return self._components[component_class].get(entity_id, None) | |
def get_components(self, entity_id: EntityId, component_classes): | |
""" | |
:return | |
Возвращает требуемые компоненты сущности. | |
Возвращает None если сущность не существует или не имеет всех этих компонентов | |
""" | |
try: | |
return tuple(self._unsafe_get_component(entity_id, component_class) | |
for component_class in component_classes) | |
except KeyError: | |
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Protocol, Type, TypeVar, overload, Callable, Any, Iterator | |
from ecs_types import EntityId, Component, StoredSystem | |
Component1 = TypeVar('Component1') | |
Component2 = TypeVar('Component2') | |
Component3 = TypeVar('Component3') | |
Component4 = TypeVar('Component4') | |
class EntityComponentSystem(Protocol): | |
_systems: dict[Callable, StoredSystem] | |
_components: dict[Type[Component], dict[EntityId, Component]] | |
_entities: list[EntityId] | |
_vars: dict[str, Any] | |
on_create: Callable[[EntityId, list[Component]], None] | |
on_remove: Callable[[EntityId], None] | |
def __init__(self, on_create: Callable[[EntityId, list[Component]], None] = None, | |
on_remove: Callable[[EntityId], None] = None): ... | |
@overload | |
def _unsafe_get_component(self, entity_id: str, component_class: Type[Component1]) -> Component1: ... | |
@overload | |
def init_component(self, component_class: Type[Component1]) -> None: ... | |
@overload | |
def init_system(self, system: Callable): ... | |
@overload | |
def add_variable(self, variable_name: str, variable_value: Any) -> None: ... | |
@overload | |
def create_entity(self, components: list[Component1], entity_id=None) -> EntityId: ... | |
@overload | |
def get_entity_ids_with_components(self, class1: Type[Component1]) -> set[EntityId]: ... | |
@overload | |
def get_entity_ids_with_components(self, class1: Type[Component1], class2: Type[Component2]) -> set[EntityId]: ... | |
@overload | |
def get_entity_ids_with_components(self, class1: Type[Component1], class2: Type[Component2], class3: Type[Component3]) -> set[EntityId]: ... | |
@overload | |
def get_entity_ids_with_components(self, class1: Type[Component1], class2: Type[Component2], class3: Type[Component3], class4: Type[Component4]) -> set[EntityId]: ... | |
@overload | |
def get_entities_with_components(self, class1: Type[Component1]) -> Iterator[tuple[ | |
EntityId, tuple[Component1]]]: ... | |
@overload | |
def get_entities_with_components(self, class1: Type[Component1], class2: Type[Component2]) -> Iterator[ | |
tuple[ | |
EntityId, tuple[Component1, Component2]]]: ... | |
@overload | |
def get_entities_with_components(self, class1: Type[Component1], class2: Type[Component2], class3: Type[Component3]) -> \ | |
Iterator[tuple[EntityId, tuple[Component1, Component2, Component3]]]: ... | |
@overload | |
def get_entities_with_components(self, class1: Type[Component1], class2: Type[Component2], class3: Type[Component3], class4: Type[Component4]) -> Iterator[tuple[ | |
EntityId, tuple[Component1, Component2, Component3, Component4]]]: ... | |
def update(self) -> None: ... | |
def remove_entity(self, entity_id: EntityId): ... | |
def get_component(self, entity_id: EntityId, component_class: Type[Component1]) -> Component1: ... | |
@overload | |
def get_components(self, entity_id: EntityId, | |
component_classes: tuple[Type[Component1]]) -> tuple[Component1]: ... | |
@overload | |
def get_components(self, entity_id: EntityId, | |
component_classes: tuple[Type[Component1], Type[Component2]]) -> tuple[ | |
Component1, Component2]: ... | |
@overload | |
def get_components(self, entity_id: EntityId, | |
component_classes: tuple[Type[Component1], Type[Component2], Type[Component3]]) -> tuple[ | |
Component1, Component2, Component3]: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from entity_component_system import EntityComponentSystem | |
from ecs_types import EntityId | |
from dataclasses import dataclass, field | |
import math | |
import pygame | |
from pygame import Color | |
from pygame.time import Clock | |
@dataclass(slots=True) | |
class ColliderComponent: | |
x: float | |
y: float | |
radius: float | |
def distance(self, other: 'ColliderComponent'): | |
return math.sqrt((self.x - other.x) ** 2 + (self.y - other.y) ** 2) | |
def is_intersecting(self, other: 'ColliderComponent'): | |
return self.distance(other) <= self.radius + other.radius | |
@dataclass(slots=True) | |
class VelocityComponent: | |
speed_x: float = 0.0 | |
speed_y: float = 0.0 | |
@dataclass(slots=True) | |
class DamageOnContactComponent: | |
damage: int | |
die_on_contact: bool = True | |
def create_arrow(x: float, y: float, angle: int, speed: float, damage: int): | |
arrow_radius = 15 | |
return [ | |
ColliderComponent(x, y, arrow_radius), | |
VelocityComponent( | |
speed_x=math.cos(math.radians(angle)) * speed, | |
speed_y=math.sin(math.radians(angle)) * speed | |
), | |
DamageOnContactComponent(damage) | |
] | |
@dataclass(slots=True) | |
class HealthComponent: | |
max_amount: int | |
amount: int = field(default=None) | |
def __post_init__(self): | |
if self.amount is None: | |
self.amount = self.max_amount | |
def apply_damage(self, damage: int): | |
self.amount = max(0, self.amount - damage) | |
def create_dummy(x: float, y: float, health: int): | |
dummy_radius = 50 | |
return [ | |
ColliderComponent(x, y, dummy_radius), | |
HealthComponent( | |
max_amount=health, | |
) | |
] | |
def velocity_system(velocity: VelocityComponent, collider: ColliderComponent): | |
collider.x += velocity.speed_x | |
collider.y += velocity.speed_y | |
def damage_on_contact_system(entity_id: EntityId, | |
ecs: EntityComponentSystem, | |
damage_on_contact: DamageOnContactComponent, | |
collider: ColliderComponent): | |
for enemy_id, (enemy_health, enemy_collider) in ecs.get_entities_with_components(HealthComponent, | |
ColliderComponent): | |
if collider.is_intersecting(enemy_collider): | |
enemy_health.apply_damage(damage_on_contact.damage) | |
if damage_on_contact.die_on_contact: | |
ecs.remove_entity(entity_id) | |
return | |
def death_system(entity_id: EntityId, health: HealthComponent, ecs: EntityComponentSystem): | |
if health.amount <= 0: | |
ecs.remove_entity(entity_id) | |
ecs = EntityComponentSystem() | |
ecs.init_component(ColliderComponent) | |
ecs.init_component(VelocityComponent) | |
ecs.init_component(DamageOnContactComponent) | |
ecs.init_component(HealthComponent) | |
ecs.init_system(velocity_system) | |
ecs.init_system(damage_on_contact_system) | |
ecs.init_system(death_system) | |
ecs.create_entity(create_arrow(x=0, y=0, angle=45, speed=2, damage=50)) | |
ecs.create_entity(create_arrow(x=500, y=0, angle=135, speed=1.5, damage=50)) | |
ecs.create_entity(create_arrow(x=0, y=500, angle=-45, speed=1.1, damage=50)) | |
ecs.create_entity(create_arrow(x=500, y=500, angle=-135, speed=1, damage=50)) | |
ecs.create_entity(create_dummy(x=250, y=250, health=200)) | |
screen = pygame.display.set_mode((500, 500)) | |
running = True | |
clock = Clock() | |
while running: | |
for event in pygame.event.get(): | |
if event.type == pygame.QUIT: | |
running = False | |
ecs.update() | |
screen.fill((93, 161, 48)) | |
for entity_id, (collider,) in ecs.get_entities_with_components(ColliderComponent): | |
pygame.draw.circle(screen, Color('gray'), (collider.x, collider.y), collider.radius) | |
pygame.display.flip() | |
clock.tick(60) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UniqueIdGenerator: | |
last_id = 0 | |
@classmethod | |
def generate_id(cls) -> str: | |
cls.last_id += 1 | |
return str(cls.last_id) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment