Last active
September 26, 2022 14:16
-
-
Save neuromaancer/ada3534ded3c1aa6a3f39e83ef930a1f to your computer and use it in GitHub Desktop.
utils
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
@Created Date: Friday February 4th 2022 | |
@Author: Alafate Abulimiti | |
@Company: INRIA | |
@Lab: CoML/Articulab | |
@School: PSL/ENS | |
@Description: Save the frequent useful functions | |
-------------- | |
@HISTORY: | |
Date By Comments | |
---------------- ----- ----------------------------------------------------------------- | |
13-02-2022 06:12:32 Alafate Abulimiti add get_segments function | |
13-02-2022 05:47:41 Alafate Abulimiti add check_identity function | |
7-02-2022 01:25:50 Alafate Abulimiti modify extract dyad and session function with regex | |
4-02-2022 03:19:36 Alafate Abulimiti add insert_row function | |
4-02-2022 02:01:23 Alafate Abulimiti add rename file function | |
4-02-2022 01:29:37 Alafate add get role pair function | |
4-02-2022 12:49:8 Alafate add round timestamps function | |
4-02-2022 11:48:45 Alafate add extract dyad session from a string | |
""" | |
import pandas as pd | |
from pathlib import Path | |
from rich import print as rprint | |
import re | |
def extract_dyad_session(s):
    """
    extract_dyad_session extract the dyad and session numbers from a string.

    The dyad is the first "D" followed by one or two digits (e.g. "D12"),
    and the session is the first "S" followed by one digit (e.g. "S1").
    They may appear anywhere in the string; no delimiter is required.

    Args:
        s (str): string such as a file name, e.g. "D12_S1_transcript.csv"

    Returns:
        tuple[int, int]: (dyad, session)

    Raises:
        ValueError: if the dyad or session pattern is not found in ``s``.
    """
    dyad_match = re.search(r"D([0-9]{1,2})", s)
    session_match = re.search(r"S([0-9]{1})", s)
    # Fail with a clear message instead of "'NoneType' object has no attribute 'group'".
    if dyad_match is None:
        raise ValueError(f"no dyad pattern 'D<digits>' found in {s!r}")
    if session_match is None:
        raise ValueError(f"no session pattern 'S<digit>' found in {s!r}")
    return int(dyad_match.group(1)), int(session_match.group(1))
def round_timestamps(
    df: pd.DataFrame,
    begin: str = "Begin Time - hh:mm:ss.ms",
    end: str = "End Time - hh:mm:ss.ms",
    level: str = "100ms",
    time_format: str = "%H:%M:%S.%f",
) -> pd.DataFrame:
    """
    round_timestamps round (ceil) the begin and end timestamps of a dataframe.

    Each of the two columns is parsed with ``time_format``, rounded *up*
    to the nearest multiple of ``level``, and stored back as
    ``datetime.time`` objects.

    Note: the columns are overwritten in place, so ``df`` itself is
    modified; the same object is returned for convenience.

    Args:
        df (DataFrame): dataframe with timestamps, normally a transcript file or an annotation file.
        begin (str, optional): Begin time column name. Defaults to "Begin Time - hh:mm:ss.ms".
        end (str, optional): End time column name. Defaults to "End Time - hh:mm:ss.ms".
        level (str, optional): Rounding frequency (pandas offset alias). Defaults to "100ms".
        time_format (str, optional): strptime format of the raw timestamps. Defaults to "%H:%M:%S.%f".

    Returns:
        DataFrame: ``df`` with rounded timestamps.
    """
    for col in (begin, end):
        parsed = pd.to_datetime(df[col], format=time_format)
        # Vectorized equivalent of applying pd.Timestamp.ceil element by element.
        df[col] = parsed.dt.ceil(level).dt.time
    return df
def get_role_pair(
    df: pd.DataFrame, period: str, dyad: str, session: str, key="p"
) -> dict:
    """
    get_role_pair get the participant/role pairing for one period, dyad and session.

    Rows are selected where the "Dyad", "Session" and "Period" columns equal
    the given values (plain ``==``, so the arguments must match the column
    dtypes, e.g. int vs str).

    Args:
        df (pd.DataFrame): df with "Period", "Dyad", "Session", "Role" and "Participant" columns
        period (str): period -> T: task, S: social
        dyad (str): dyad
        session (str): session: 1, 2
        key (str, optional): "p" to key the result by participant,
            "r" to key it by role. Defaults to "p".

    Returns:
        dict: {participant: role} when key == "p", {role: participant} when key == "r".

    Raises:
        ValueError: if ``key`` is neither "p" nor "r".
    """
    if key not in ("p", "r"):
        raise ValueError(f"key must be 'p' or 'r', got {key!r}")
    selected = df.loc[
        (df["Dyad"] == dyad) & (df["Session"] == session) & (df["Period"] == period)
    ]
    roles = selected["Role"].tolist()
    participants = selected["Participant"].tolist()
    if key == "p":
        return dict(zip(participants, roles))
    return dict(zip(roles, participants))
def rename_file(f: Path, addition: str, position: str = "postfix") -> Path:
    """
    rename_file build a new file path with ``addition`` added to the file stem.

    "postfix" gives ``<stem>_<addition><suffix>``; "prefix" gives
    ``<addition>_<stem><suffix>``. The parent directory is kept. Only the
    path is computed; nothing is renamed on disk.

    Args:
        f (Path): File path (a ``str`` is also accepted)
        addition (str): Additional string
        position (str, optional): Where to put the additional string,
            "postfix" or "prefix". Defaults to "postfix".

    Returns:
        Path: Renamed file `pathlib` Path object

    Raises:
        ValueError: if ``position`` is neither "postfix" nor "prefix".
    """
    f = Path(f)
    if position == "postfix":
        new_name = f"{f.stem}_{addition}{f.suffix}"
    elif position == "prefix":
        new_name = f"{addition}_{f.stem}{f.suffix}"
    else:
        raise ValueError(f"position must be 'postfix' or 'prefix', got {position!r}")
    return f.parent / new_name
def insert_row(df: pd.DataFrame, row_number: int, row_value: dict) -> pd.DataFrame:
    """
    insert_row insert a row in a dataframe at a given (positional) row number.

    The input dataframe is not modified. The result gets a fresh 0..n-1 index.

    Args:
        df (pd.DataFrame): target dataframe
        row_number (int): position of the new row (rows from this position
            onward are shifted down)
        row_value (dict): row value with dict format {column_name: value};
            columns missing from the dict become NaN, unknown keys are ignored

    Returns:
        pd.DataFrame: new dataframe with inserted row
    """
    # Positional slices; the original label-based `.loc[row_number] = ...` on a
    # slice raised SettingWithCopyWarning and could overwrite an existing row
    # when the index was not a plain RangeIndex.
    upper = df.iloc[:row_number]
    lower = df.iloc[row_number:]
    new_row = pd.DataFrame([row_value], columns=df.columns)
    # Skip empty pieces so concat does not have to infer dtypes from them.
    parts = [part for part in (upper, new_row, lower) if not part.empty]
    return pd.concat(parts, ignore_index=True)
def check_identity(
    rapport_df: pd.DataFrame,
    line: int,
    reference: int,
    cols: "list[str] | None" = None,
):
    """
    check_identity check whether two rows agree on all of the given columns.

    Used to tell whether a transcript line and a reference line belong to the
    same dyad and session.

    Args:
        rapport_df (pd.DataFrame): 2016 dataframe with rapport annotations
        line (int): transcript row position (``iloc``)
        reference (int): reference row position (``iloc``)
        cols (list, optional): columns to compare. Defaults to ["Dyad", "Session"].

    Returns:
        bool: True if every column in ``cols`` is equal in both rows, else False.
    """
    # None sentinel instead of a shared mutable list default.
    cols = ["Dyad", "Session"] if cols is None else list(cols)
    line_values = rapport_df.iloc[line][cols]
    reference_values = rapport_df.iloc[reference][cols]
    return bool((line_values == reference_values).all())
def get_segments(segment_idx: list[int]) -> list[list[int]]:
    """
    get_segments group a list of indexes into runs of consecutive integers.

    example: [1, 2, 3, 5, 6, 8] -> [[1, 2, 3], [5, 6], [8]]

    Args:
        segment_idx (list[int]): list with indexes, in order.

    Returns:
        list[list[int]]: the runs, each a list of consecutive indexes. An
        empty input gives an empty list.
    """
    segments: list[list[int]] = []
    for idx in segment_idx:
        # Extend the current run when idx directly follows its last element,
        # otherwise start a new run (this also handles the very first index).
        if segments and idx - 1 == segments[-1][-1]:
            segments[-1].append(idx)
        else:
            segments.append([idx])
    return segments
if __name__ == "__main__":
    # Small demo: pair each role with a participant label and pretty-print it.
    roles = ["Tutor", "Tutee"]
    participants = ["P1", "P2"]
    role_to_participant = {role: person for role, person in zip(roles, participants)}
    rprint(role_to_participant)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment