Created
August 3, 2024 04:28
-
-
Save colonelpanic8/d33830f0c0cdbd973d7a801de93e711b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import numpy as np | |
import pytest | |
import sqlalchemy as sa | |
from railbird import containers, datatypes, shot_parsing, video | |
from railbird.datatypes import models | |
def assert_one_and_get(elems) -> "shot_parsing.FeatureExtraction":
    """Assert that ``elems`` holds exactly one item and return that item.

    Used by the tests in this module to unwrap the result of
    ``get_feature_extractions`` when a clip is expected to contain one shot.

    Args:
        elems: A sized, indexable sequence (e.g. a list of feature extractions).

    Returns:
        The sole element of ``elems``.

    Raises:
        AssertionError: If ``elems`` does not contain exactly one item. The
            message reports how many items were actually found.
    """
    count = len(elems)
    # Include the real count so a failure reads "got 0" / "got 2" directly.
    assert count == 1, f"expected exactly one element, got {count}"
    return elems[0]
@pytest.fixture
def get_feature_extractions(video_container_factory):
    """Provide a helper that runs a video container and returns its extractions.

    The returned callable uses ``video_container`` when one is given. Otherwise
    it builds a container with ``video_container_factory(*args, **kwargs)``.
    When ``video_container`` is given, ``args``/``kwargs`` are not used.
    The helper then does four things in order:
    1. Attaches a feature-extraction list listener.
    2. Finishes processing the video from ``start_frame``.
    3. Waits (``timeout=3``) for that listener's queue to empty.
    4. Returns ``container.feature_extractions()``.
    """

    def _get(*args, video_container=None, start_frame=0, **kwargs):
        # Test against None explicitly. A caller-supplied container must never
        # be replaced just because it happens to evaluate as falsy.
        if video_container is None:
            container = video_container_factory(*args, **kwargs)
        else:
            container = video_container
        listener = container.write_feature_extractions_to_list_listener()
        container.finish_processing_video(start_frame=start_frame)
        container.wait_for_queues_to_empty(listener.queue, timeout=3)
        return container.feature_extractions()

    return _get
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_miss_overcut(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
    identifier_path_factory,
):
    """An overcut miss yields a positive miss angle that matches the stored shot."""
    path_to_video_file = identifier_path_factory(
        "ef12a24848e9fda72f7f6038d31edc4089a4ec588e342230c940bc64255da31a",
        "overcut.mp4",
    )
    container = video_container_factory(path_to_video_file, in_sample_videos=False)
    db_writer = container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )
    feature_extraction = assert_one_and_get(
        get_feature_extractions(video_container=container)
    )
    assert feature_extraction.miss_angle is not None
    assert feature_extraction.miss_angle > 0
    # get_feature_extractions only waits on its own list listener's queue.
    # Also wait on the DB writer's queue, as test_jaws_combos does, so the
    # shot row has been handed to the database before it is queried below.
    container.wait_for_queues_to_empty(db_writer.feature_listener.queue)
    with sync_sessionmaker() as session:
        with session.begin():
            result = session.execute(sa.select(models.ShotModel))
            shot = result.unique().scalars().one()
            np.testing.assert_almost_equal(
                float(shot.miss_features.miss_angle),
                feature_extraction.miss_angle,
                decimal=3,
            )
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_miss_undercut(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
    identifier_path_factory,
):
    """An undercut miss yields a negative miss angle that matches the stored shot."""
    path_to_video_file = identifier_path_factory(
        "ef12a24848e9fda72f7f6038d31edc4089a4ec588e342230c940bc64255da31a",
        "undercut.mp4",
    )
    container = video_container_factory(path_to_video_file, in_sample_videos=False)
    db_writer = container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )
    feature_extraction = assert_one_and_get(
        get_feature_extractions(video_container=container)
    )
    assert feature_extraction.miss_angle is not None
    assert feature_extraction.miss_angle < 0
    # get_feature_extractions only waits on its own list listener's queue.
    # Also wait on the DB writer's queue, as test_jaws_combos does, so the
    # shot row has been handed to the database before it is queried below.
    container.wait_for_queues_to_empty(db_writer.feature_listener.queue)
    with sync_sessionmaker() as session:
        with session.begin():
            result = session.execute(sa.select(models.ShotModel))
            shot = result.unique().scalars().one()
            np.testing.assert_almost_equal(
                float(shot.miss_features.miss_angle),
                feature_extraction.miss_angle,
                decimal=3,
            )
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_jaws_combos(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
):
    """Every extracted shot in ``4.mov`` is a make with distinct object/target balls.

    Three shots must be extracted, and three ``ShotModel`` rows must be stored.
    Each stored row must have truthy cue/object angle and cue ball speed features.
    """
    video_container = video_container_factory("4.mov")
    writer = video_container.pool_session_writer(
        video_factory.default, user_id=user_factory.default.id
    )
    extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=video_container
    )
    for extraction in extractions:
        assert extraction.make
        assert extraction.object_ball != extraction.target_ball

    # Let the DB writer drain before reading rows back.
    video_container.wait_for_queues_to_empty(writer.feature_listener.queue)
    assert len(extractions) == 3

    with sync_sessionmaker() as session, session.begin():
        stored_shots = (
            session.execute(sa.select(models.ShotModel)).unique().scalars().all()
        )
        assert len(stored_shots) == 3
        for stored in stored_shots:
            features = stored.cue_object_features
            assert features.cue_object_angle
            assert features.cue_ball_speed2
@pytest.mark.runs_model
def test_carom(get_feature_extractions):
    """The single shot in ``3.mov`` is a make whose object ball is the target.

    Its first pocketing path must differ from the object ball's path from the
    collision. That path must start at a contact involving more than one ball,
    none of which is the cue ball.
    """
    extraction = assert_one_and_get(get_feature_extractions("3.mov"))
    assert extraction.make
    assert extraction.object_ball == extraction.target_ball

    paths_for_object_ball = extraction.pocketing_paths[extraction.object_ball]
    pocketing_path = paths_for_object_ball[0]
    assert extraction.object_path_from_collision.info != pocketing_path.info

    start_ball_ids = pocketing_path.info.start_info.ball_identifiers
    assert extraction.cue_ball not in start_ball_ids
    assert len(start_ball_ids) > 1

    # Smoke check: building a ShotModel from this extraction must not raise.
    # The positional arguments are (video id, user id, extraction), as in
    # test_kick_shot.
    video_id, user_id = 0, 1
    models.ShotModel.create_from_feature_extraction(video_id, user_id, extraction)
@pytest.mark.runs_model
def test_shot_detection_and_paths(get_feature_extractions):
    """Check intention, make, direction and cue-ball path endings for ``2.webm``."""
    extraction = assert_one_and_get(get_feature_extractions("2.webm"))
    intended_pocket = extraction.intention_info[2].identifier
    assert intended_pocket == datatypes.PocketIdentifier.BOTTOM_RIGHT
    assert extraction.make

    shot = extraction.shot
    # object_ball_went_left may be None. A bare `assert not went_left` would
    # pass on None, so require a value that is both non-None and falsy.
    went_left = extraction.object_ball_went_left
    assert went_left is not None
    assert not went_left

    cue_ball_paths = shot.id_to_data[extraction.cue_ball].moving_path_metas
    # The first cue-ball path must end on contact with two ball identifiers.
    assert len(cue_ball_paths[0].end_info.ball_identifiers) == 2
    # The next three paths must end on these walls, in order.
    expected_walls = (
        datatypes.WallIdentifier.RIGHT,
        datatypes.WallIdentifier.TOP,
        datatypes.WallIdentifier.BOTTOM,
    )
    for index, wall in enumerate(expected_walls, start=1):
        assert cue_ball_paths[index].end_info.wall_identifier == wall
@pytest.mark.runs_model
def test_a_few_makes(
    get_feature_extractions, identifier_path_factory, testing_container
):
    """A segmented video source yields three makes that all share one cue ball."""

    def segment_path(index):
        # Segment files are named by zero-padded index: 000.mp4, 001.mp4, ...
        return identifier_path_factory(
            "12b27c3a-489f-49b8-93db-b6f13b4f43c0", f"{index:03}.mp4"
        )

    segment_info = video.SimpleSegmentedVideoInfo(segment_path)
    source_container = containers.VideoSourceContainer(
        rbd=testing_container,
        video_source=video.SegmentedVideoSource(segment_info),
    )
    extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=source_container
    )

    assert len(extractions) == 3
    for extraction in extractions:
        assert extraction.make
        assert extraction.difficulty
    distinct_cue_balls = {extraction.cue_ball for extraction in extractions}
    assert len(distinct_cue_balls) == 1
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_is_direct_shot(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """The single shot in ``direct1.mp4`` is flagged ``is_direct``."""
    # NOTE(review): sync_session, user_factory and video_factory are requested
    # but never used in this body. Presumably they are kept for fixture setup;
    # confirm before removing them.
    video_path = identifier_path_factory(
        "80eb94a76f8069b62d0cdfcf52d3a7fb49b4fd60723e51d3bbd1df25d015ed3c",
        "direct1.mp4",
    )
    extractions = get_feature_extractions(video_path, in_sample_videos=False)
    only_extraction = assert_one_and_get(extractions)
    assert only_extraction.is_direct
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_cue_features_after_object(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """Cue speed and angle after object contact are present and within range.

    The speed must lie strictly between 20 and 50. The angle, in degrees,
    must lie strictly between 80 and 100.
    """
    video_path = identifier_path_factory(
        "4f292947c860a427a60e8eb6703c86cd32c6eeb2b71fcc40839a2f98db67eda4",
        "cue-features-after-object1.mp4",
    )
    extraction = assert_one_and_get(
        get_feature_extractions(video_path, in_sample_videos=False)
    )
    speed = extraction.cue_speed_after_object
    angle_degrees = extraction.cue_angle_after_object_degrees
    assert speed is not None
    assert angle_degrees is not None
    assert 20 < speed < 50
    assert 80 < angle_degrees < 100
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_kick_shot(
    get_feature_extractions,
    sync_session,
    user_factory,
    video_factory,
    identifier_path_factory,
):
    """``kick1.mp4`` is detected as a kick, and its kick features are stored.

    The extraction must report at least one kick wall and a kick angle strictly
    between 5 and 80. The ``ShotModel`` created from it must load back by id
    with kick features whose angle is below 80.
    """
    video_path = identifier_path_factory(
        "77302f0cc50a1e143d4a8e8037a240add45f5059a207e1d4e58e33f6c1e61ca9",
        "kick1.mp4",
    )
    extraction = assert_one_and_get(
        get_feature_extractions(video_path, in_sample_videos=False)
    )
    assert extraction.is_kick

    walls = extraction.kick_walls
    assert walls is not None
    assert len(walls) > 0

    angle = extraction.kick_angle
    assert angle is not None
    assert 5 < angle < 80

    with sync_session.begin():
        new_shot = models.ShotModel.create_from_feature_extraction(
            video_factory.default.id, user_factory.default.id, extraction
        )
        sync_session.add(new_shot)

    # Load the shot back by id in a second transaction.
    with sync_session.begin():
        reloaded = models.ShotModel.get_by_id(sync_session, new_shot.id)
        kick_features = reloaded.kick_features
        assert kick_features is not None
        assert kick_features.angle < 80
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_serialize_path(
    get_feature_extractions, sync_session, user_factory, video_factory
):
    """A stored ``bank1.mp4`` shot loads back with serialized paths for 5 ids."""
    extraction = assert_one_and_get(get_feature_extractions("bank1.mp4"))

    with sync_session.begin():
        new_shot = models.ShotModel.create_from_feature_extraction(
            video_factory.default.id, user_factory.default.id, extraction
        )
        sync_session.add(new_shot)

    with sync_session.begin():
        reloaded = models.ShotModel.get_by_id(sync_session, new_shot.id)
        pathed_shot = reloaded.serialized_shot_paths.pathed_shot
        assert len(pathed_shot.id_to_data) == 5
@pytest.mark.writes_db
@pytest.mark.runs_model
def test_process_from_arbitrary_frame(
    video_container_factory,
    get_feature_extractions,
    video_factory,
    sync_sessionmaker,
    user_factory,
):
    """Processing from a later start frame reproduces the later shots.

    ``4.mov`` is first processed in full into one video record (3 extractions).
    It is then processed again into a second video record, starting at the
    smallest shot end frame from the full run. That run must yield 2
    extractions. Its stored shots must match the full run's shots after the
    first one, with start/end frames within a small tolerance.
    """

    def shot_frame_spans(video_id):
        # Copy the (start_frame, end_frame) pairs out while the session and
        # transaction are still open, so later comparisons never touch
        # expired or detached ORM instances. Sort them so the pairing below
        # does not depend on the order of the `shots` relationship.
        with sync_sessionmaker() as session:
            with session.begin():
                video_model = models.VideoModel.get_by_id(
                    session, video_id, eager=True
                )
                return sorted(
                    (shot.start_frame, shot.end_frame) for shot in video_model.shots
                )

    def frames_almost_equal(frame1, frame2, tolerance: int = 5):
        return abs(frame1 - frame2) <= tolerance

    # Process the entire video.
    container = video_container_factory("4.mov")
    original_video = video_factory()
    db_writer = container.pool_session_writer(
        original_video, user_id=user_factory.default.id
    )
    feature_extractions: List[shot_parsing.FeatureExtraction] = get_feature_extractions(
        video_container=container
    )
    container.wait_for_queues_to_empty(db_writer.feature_listener.queue)
    assert len(feature_extractions) == 3
    original_spans = shot_frame_spans(original_video.id)
    min_end_frame = min(end_frame for _, end_frame in original_spans)

    # Process the video again, starting from min_end_frame.
    partial_container = video_container_factory("4.mov")
    partial_video: models.VideoModel = video_factory()
    partial_db_writer = partial_container.pool_session_writer(
        partial_video, user_id=user_factory.default.id
    )
    partial_feature_extractions: List[shot_parsing.FeatureExtraction] = (
        get_feature_extractions(
            video_container=partial_container, start_frame=min_end_frame
        )
    )
    assert len(partial_feature_extractions) == 2
    partial_container.wait_for_queues_to_empty(partial_db_writer.feature_listener.queue)
    partial_spans = shot_frame_spans(partial_video.id)

    # The partial run should skip the earliest shot and reproduce the rest.
    expected_spans = original_spans[1:]
    assert len(expected_spans) == len(partial_spans), (expected_spans, partial_spans)
    for (orig_start, orig_end), (part_start, part_end) in zip(
        expected_spans, partial_spans
    ):
        assert frames_almost_equal(orig_start, part_start), (orig_start, part_start)
        assert frames_almost_equal(orig_end, part_end), (orig_end, part_end)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment