Skip to content

Instantly share code, notes, and snippets.

@colonelpanic8
Created August 3, 2024 04:28
Show Gist options
  • Save colonelpanic8/d33830f0c0cdbd973d7a801de93e711b to your computer and use it in GitHub Desktop.
Save colonelpanic8/d33830f0c0cdbd973d7a801de93e711b to your computer and use it in GitHub Desktop.
from typing import List
import numpy as np
import pytest
import sqlalchemy as sa
from railbird import containers, datatypes, shot_parsing, video
from railbird.datatypes import models
def assert_one_and_get(elems) -> "shot_parsing.FeatureExtraction":
    """Assert that *elems* holds exactly one item and return that item.

    Any iterable is accepted. It is materialized into a list first, so
    generators work too; a bare ``len(elems)`` would require a sized
    sequence.

    Raises:
        AssertionError: if *elems* does not contain exactly one item. The
            message includes the actual count so failures stay readable
            even where pytest's assertion rewriting is not active.
    """
    items = list(elems)
    assert len(items) == 1, f"expected exactly one element, got {len(items)}"
    return items[0]
@pytest.fixture
def get_feature_extractions(video_container_factory):
    """Provide a helper that runs a video container and returns its feature extractions.

    The returned callable ``_get(*args, video_container=None, start_frame=0,
    **kwargs)`` does the following:

    - Uses ``video_container`` when given. Otherwise it builds one with
      ``video_container_factory(*args, **kwargs)``.
    - Registers the listener from
      ``container.write_feature_extractions_to_list_listener()``.
    - Calls ``container.finish_processing_video(start_frame=start_frame)``.
    - Waits on the listener's queue with ``timeout=3``.
    - Returns ``container.feature_extractions()``.
    """

    def _get(*args, video_container=None, start_frame=0, **kwargs):
        # Compare against None explicitly. A caller-supplied container must
        # never be silently swapped for a fresh one just because the object
        # happens to be falsy (e.g. it defines __len__ or __bool__).
        if video_container is None:
            container = video_container_factory(*args, **kwargs)
        else:
            container = video_container
        listener = container.write_feature_extractions_to_list_listener()
        container.finish_processing_video(start_frame=start_frame)
        container.wait_for_queues_to_empty(listener.queue, timeout=3)
        return container.feature_extractions()

    return _get
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_miss_overcut(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
    identifier_path_factory,
):
    """``overcut.mp4`` yields one extraction with a positive miss angle.

    The miss angle stored on the persisted ShotModel must match the
    in-memory extraction to 3 decimal places.
    """
    overcut_path = identifier_path_factory(
        "ef12a24848e9fda72f7f6038d31edc4089a4ec588e342230c940bc64255da31a",
        "overcut.mp4",
    )
    container = video_container_factory(overcut_path, in_sample_videos=False)
    # Register a session writer so the processed shot can be read back
    # through ShotModel below.
    container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )

    extraction = assert_one_and_get(
        get_feature_extractions(video_container=container)
    )
    angle = extraction.miss_angle
    assert angle is not None
    assert angle > 0

    with sync_sessionmaker() as session, session.begin():
        query_result = session.execute(sa.select(models.ShotModel))
        persisted = query_result.unique().scalars().one()
        np.testing.assert_almost_equal(
            float(persisted.miss_features.miss_angle), angle, decimal=3
        )
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_miss_undercut(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
    identifier_path_factory,
):
    """``undercut.mp4`` yields one extraction with a negative miss angle.

    The angle on the persisted shot's ``miss_features`` must agree with the
    in-memory extraction to 3 decimal places.
    """
    clip = identifier_path_factory(
        "ef12a24848e9fda72f7f6038d31edc4089a4ec588e342230c940bc64255da31a",
        "undercut.mp4",
    )
    undercut_container = video_container_factory(clip, in_sample_videos=False)
    # Hook up a DB writer; the ShotModel query below reads what it stored.
    undercut_container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )

    extractions = get_feature_extractions(video_container=undercut_container)
    single = assert_one_and_get(extractions)
    assert single.miss_angle is not None
    assert single.miss_angle < 0

    with sync_sessionmaker() as session, session.begin():
        rows = session.execute(sa.select(models.ShotModel))
        stored_shot = rows.unique().scalars().one()
        stored_angle = float(stored_shot.miss_features.miss_angle)
        np.testing.assert_almost_equal(stored_angle, single.miss_angle, decimal=3)
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_jaws_combos(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
):
    """``4.mov`` yields three made shots whose object ball differs from the target ball.

    After the DB writer's queue drains, exactly three ShotModel rows must
    exist. Each must have truthy ``cue_object_angle`` and ``cue_ball_speed2``
    values on its ``cue_object_features``.
    """
    combo_container = video_container_factory("4.mov")
    writer = combo_container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )
    extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=combo_container
    )
    for extraction in extractions:
        assert extraction.make
        assert extraction.object_ball != extraction.target_ball

    # Let the DB writer finish before reading back what it persisted.
    combo_container.wait_for_queues_to_empty(writer.feature_listener.queue)
    assert len(extractions) == 3

    with sync_sessionmaker() as session, session.begin():
        rows = session.execute(sa.select(models.ShotModel))
        persisted = rows.unique().scalars().all()
        assert len(persisted) == 3
        for shot_row in persisted:
            features = shot_row.cue_object_features
            assert features.cue_object_angle
            assert features.cue_ball_speed2
@pytest.mark.runs_model
def test_carom(get_feature_extractions):
    """``3.mov``: a make where the object ball is also the target ball.

    The object ball's first pocketing path must differ from its path after
    the collision. That path's start info must list more than one ball
    identifier, none of which is the cue ball. Building a ShotModel from the
    extraction must not raise.
    """
    extraction = assert_one_and_get(get_feature_extractions("3.mov"))
    assert extraction.make
    object_ball = extraction.object_ball
    assert object_ball == extraction.target_ball

    first_pocketing_path = extraction.pocketing_paths[object_ball][0]
    assert extraction.object_path_from_collision.info != first_pocketing_path.info

    balls_at_start = first_pocketing_path.info.start_info.ball_identifiers
    assert extraction.cue_ball not in balls_at_start
    assert len(balls_at_start) > 1

    # Smoke check: building a model (placeholder video/user ids 0 and 1)
    # must not raise.
    models.ShotModel.create_from_feature_extraction(0, 1, extraction)
@pytest.mark.runs_model
def test_shot_detection_and_paths(get_feature_extractions):
    """Check intention, make, direction and cue-ball path endpoints for ``2.webm``.

    The third intention entry must be the bottom-right pocket, and the shot
    must be a make. ``object_ball_went_left`` must be determined and false.
    The cue ball's moving paths must end as follows:

    - path 0 at a contact listing two balls;
    - paths 1–3 at the RIGHT, TOP and BOTTOM walls respectively.
    """
    feature_extraction = assert_one_and_get(get_feature_extractions("2.webm"))
    assert (
        feature_extraction.intention_info[2].identifier
        == datatypes.PocketIdentifier.BOTTOM_RIGHT
    )
    assert feature_extraction.make
    shot = feature_extraction.shot

    # object_ball_went_left may be None, and a bare ``assert not went_left``
    # would also pass on None. Require a non-None falsy value instead.
    # ``is False`` is avoided in case the value is a numpy bool (not confirmed
    # here).
    went_left = feature_extraction.object_ball_went_left
    assert went_left is not None and not went_left

    paths = shot.id_to_data[feature_extraction.cue_ball].moving_path_metas
    # Guard the positional indexing below so a short path list fails with an
    # explicit assertion instead of an IndexError.
    assert len(paths) >= 4, f"expected at least 4 cue-ball paths, got {len(paths)}"
    assert len(paths[0].end_info.ball_identifiers) == 2
    assert paths[1].end_info.wall_identifier == datatypes.WallIdentifier.RIGHT
    assert paths[2].end_info.wall_identifier == datatypes.WallIdentifier.TOP
    assert paths[3].end_info.wall_identifier == datatypes.WallIdentifier.BOTTOM
@pytest.mark.runs_model
def test_a_few_makes(
    get_feature_extractions, identifier_path_factory, testing_container
):
    """A segmented video yields three makes with truthy difficulty and one shared cue ball.

    Segment files are resolved under identifier
    ``12b27c3a-489f-49b8-93db-b6f13b4f43c0`` as zero-padded names
    (``000.mp4``, ``001.mp4``, ...).
    """

    def segment_path(index):
        return identifier_path_factory(
            "12b27c3a-489f-49b8-93db-b6f13b4f43c0", f"{index:03}.mp4"
        )

    segmented_source = video.SegmentedVideoSource(
        video.SimpleSegmentedVideoInfo(segment_path)
    )
    source_container = containers.VideoSourceContainer(
        rbd=testing_container,
        video_source=segmented_source,
    )
    extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=source_container
    )

    assert len(extractions) == 3
    for extraction in extractions:
        assert extraction.make
        assert extraction.difficulty
    distinct_cue_balls = {extraction.cue_ball for extraction in extractions}
    assert len(distinct_cue_balls) == 1
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_is_direct_shot(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """The single shot in ``direct1.mp4`` is classified as direct.

    NOTE(review): ``sync_session``, ``user_factory`` and ``video_factory`` are
    requested but unused in this body. Confirm whether their fixture setup is
    needed before dropping them.
    """
    direct_clip = identifier_path_factory(
        "80eb94a76f8069b62d0cdfcf52d3a7fb49b4fd60723e51d3bbd1df25d015ed3c",
        "direct1.mp4",
    )
    extractions = get_feature_extractions(direct_clip, in_sample_videos=False)
    only_shot = assert_one_and_get(extractions)
    assert only_shot.is_direct
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_cue_features_after_object(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """Cue-ball speed and angle after object contact fall in expected ranges.

    For ``cue-features-after-object1.mp4`` both values must be present:

    - ``cue_speed_after_object`` strictly between 20 and 50;
    - ``cue_angle_after_object_degrees`` strictly between 80 and 100.

    NOTE(review): ``sync_session``, ``user_factory`` and ``video_factory`` are
    requested but unused in this body.
    """
    clip = identifier_path_factory(
        "4f292947c860a427a60e8eb6703c86cd32c6eeb2b71fcc40839a2f98db67eda4",
        "cue-features-after-object1.mp4",
    )
    extraction = assert_one_and_get(
        get_feature_extractions(clip, in_sample_videos=False)
    )

    speed = extraction.cue_speed_after_object
    assert speed is not None
    angle_degrees = extraction.cue_angle_after_object_degrees
    assert angle_degrees is not None

    assert 20 < speed < 50
    assert 80 < angle_degrees < 100
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_kick_shot(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """``kick1.mp4`` is a kick shot whose kick features survive a DB round trip.

    The extraction must be flagged as a kick, list at least one kick wall,
    and have a kick angle strictly between 5 and 80. The shot is then stored
    in one transaction and reloaded in another. The reloaded row must have
    ``kick_features`` with ``angle`` below 80.
    """
    kick_clip = identifier_path_factory(
        "77302f0cc50a1e143d4a8e8037a240add45f5059a207e1d4e58e33f6c1e61ca9",
        "kick1.mp4",
    )
    extraction = assert_one_and_get(
        get_feature_extractions(kick_clip, in_sample_videos=False)
    )
    assert extraction.is_kick

    walls = extraction.kick_walls
    assert walls is not None
    assert len(walls) > 0

    kick_angle = extraction.kick_angle
    assert kick_angle is not None
    assert 5 < kick_angle < 80

    # Write the shot in one transaction, then reload it in a separate one.
    with sync_session.begin():
        new_shot = models.ShotModel.create_from_feature_extraction(
            video_factory.default.id, user_factory.default.id, extraction
        )
        sync_session.add(new_shot)
    with sync_session.begin():
        reloaded = models.ShotModel.get_by_id(sync_session, new_shot.id)
        assert reloaded.kick_features is not None
        assert reloaded.kick_features.angle < 80
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_serialize_path(
    get_feature_extractions, sync_session, user_factory, video_factory
):
    """Shot paths from ``bank1.mp4`` survive a DB round trip.

    A ShotModel is built from the single extraction and stored in one
    transaction, then reloaded in another. The reloaded
    ``serialized_shot_paths.pathed_shot.id_to_data`` must have 5 entries.
    """
    extraction = assert_one_and_get(get_feature_extractions("bank1.mp4"))

    with sync_session.begin():
        stored = models.ShotModel.create_from_feature_extraction(
            video_factory.default.id, user_factory.default.id, extraction
        )
        sync_session.add(stored)

    with sync_session.begin():
        reloaded = models.ShotModel.get_by_id(sync_session, stored.id)
        id_to_data = reloaded.serialized_shot_paths.pathed_shot.id_to_data
        assert len(id_to_data) == 5
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_process_from_arbitrary_frame(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
):
    """Starting processing mid-video reproduces the later shots.

    ``4.mov`` is processed twice, each run writing to its own VideoModel:

    1. From the beginning, expecting 3 extractions.
    2. From the smallest ``end_frame`` among the first run's shots,
       expecting 2 extractions.

    With both runs' shots ordered by ``start_frame``, the second run's shots
    must line up with the first run's shots minus the earliest one. Start and
    end frames must each agree within 5 frames.
    """

    def frames_almost_equal(frame1, frame2, tolerance: int = 5):
        return abs(frame1 - frame2) <= tolerance

    def sorted_by_start(shots):
        # Nothing here shows VideoModel.shots being ordered, and a
        # relationship without an explicit order has no guaranteed order.
        # Sort before pairing shots positionally below.
        return sorted(shots, key=lambda shot: shot.start_frame)

    # Process the entire video.
    container = video_container_factory("4.mov")
    original_video = video_factory()
    db_writer = container.pool_session_writer(
        original_video, user_id=user_factory.default.id
    )
    feature_extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=container
    )
    container.wait_for_queues_to_empty(db_writer.feature_listener.queue)
    assert len(feature_extractions) == 3
    with sync_sessionmaker() as session:
        with session.begin():
            original_video = models.VideoModel.get_by_id(
                session, original_video.id, eager=True
            )
            original_shots = sorted_by_start(original_video.shots)
            min_end_frame = min(shot.end_frame for shot in original_shots)

    # Process the video again, starting where the earliest-ending shot ended.
    partial_container = video_container_factory("4.mov")
    partial_video: models.VideoModel = video_factory()
    partial_db_writer = partial_container.pool_session_writer(
        partial_video, user_id=user_factory.default.id
    )
    partial_feature_extractions: List[shot_parsing.FeatureExtraction] = (
        get_feature_extractions(
            video_container=partial_container, start_frame=min_end_frame
        )
    )
    assert len(partial_feature_extractions) == 2
    partial_container.wait_for_queues_to_empty(partial_db_writer.feature_listener.queue)
    with sync_sessionmaker() as session:
        with session.begin():
            partial_video = models.VideoModel.get_by_id(
                session, partial_video.id, eager=True
            )
            partial_shots = sorted_by_start(partial_video.shots)

    # The partial run is expected to skip exactly the earliest shot.
    later_original_shots = original_shots[1:]
    assert len(later_original_shots) == len(partial_shots)
    for orig_shot, partial_shot in zip(later_original_shots, partial_shots):
        assert frames_almost_equal(orig_shot.start_frame, partial_shot.start_frame)
        assert frames_almost_equal(orig_shot.end_frame, partial_shot.end_frame)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment