Last active
July 26, 2023 04:10
-
-
Save weimzh/6a95a73ff2f9c6fc06f19a5d31795e80 to your computer and use it in GitHub Desktop.
bag2cyber.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
#
# Convert an entire rosbag to Apollo Cyber Record format.
#
# Copyright (c) 2023, Wei Mingzhi <whistler_wmz@users.sf.net>.
#
# SPDX-License-Identifier: BSD-3-Clause
# https://spdx.org/licenses/BSD-3-Clause.html
#
import argparse
import os
import re
import sys
import sys

import rosbag
from cyber_record import record
from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto, FileDescriptorProto
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.message_factory import GetMessages
# Mapping from ROS1 built-in field types to the protobuf scalar type used for
# the corresponding field of the generated Apollo message.
# 'time' and 'duration' are stored as integer nanoseconds (convert_to_rosmsg
# calls to_nsec() on them).
# 'byte' and 'char' are ROS1's deprecated aliases for int8 and uint8. Without
# them, such fields would be treated as references to nonexistent messages.
basetype_map = {
    'bool': FieldDescriptorProto.TYPE_BOOL,
    'int8': FieldDescriptorProto.TYPE_SINT32,
    'uint8': FieldDescriptorProto.TYPE_UINT32,
    'int16': FieldDescriptorProto.TYPE_SINT32,
    'uint16': FieldDescriptorProto.TYPE_UINT32,
    'int32': FieldDescriptorProto.TYPE_SINT32,
    'uint32': FieldDescriptorProto.TYPE_UINT32,
    'int64': FieldDescriptorProto.TYPE_SINT64,
    'uint64': FieldDescriptorProto.TYPE_UINT64,
    'float32': FieldDescriptorProto.TYPE_FLOAT,
    'float64': FieldDescriptorProto.TYPE_DOUBLE,
    'string': FieldDescriptorProto.TYPE_STRING,
    'time': FieldDescriptorProto.TYPE_UINT64,
    'duration': FieldDescriptorProto.TYPE_SINT64,
    'byte': FieldDescriptorProto.TYPE_SINT32,
    'char': FieldDescriptorProto.TYPE_UINT32,
}
def read_bag_types(bag_path):
    """Collect the message definitions used in a rosbag.

    The bag is read once. For every distinct message type, the embedded
    ``_full_text`` definition is parsed into a field table. That text holds
    the type's own fields, followed by ``MSG: pkg/Name`` sections for the
    types it depends on.

    Args:
        bag_path: path of the ROS1 bag to inspect.

    Returns:
        A tuple ``(packages, topics)``:

        - ``packages`` maps ``package -> message name -> {field name: ROS
          field type}``. Fields keep definition order; constant definitions
          (lines containing ``=``) are skipped.
        - ``topics`` maps each topic name to the ROS type ``'pkg/Name'``
          last seen on it.
    """
    types = {}
    topics = {}
    # The context manager closes the bag even if reading fails part-way.
    with rosbag.Bag(bag_path) as bag:
        for topic, msg, _ in bag.read_messages():
            types.setdefault(msg._type, msg._full_text)
            topics[topic] = msg._type

    packages = {}

    def start_message(full_name):
        # Register a fresh, empty field table for 'pkg/Name' and return it.
        parts = full_name.split('/')
        fields = {}
        packages.setdefault(parts[0], {})[parts[1]] = fields
        return fields

    for type_name, full_text in types.items():
        fields = start_message(type_name)
        for line in full_text.split('\n'):
            if line.startswith('==='):
                # Separator between the sections of a full definition.
                continue
            if line.startswith('MSG: '):
                # Following lines belong to this dependency type.
                fields = start_message(line.split(' ')[1])
                continue
            line = line.split('#')[0].strip()
            if not line or '=' in line:
                # Blank or comment-only line, or a constant (ignored for now).
                continue
            # "<type> <name>": the type is the first token, the name the last.
            tokens = line.split()
            fields[tokens[-1]] = tokens[0]
    return (packages, topics)
def generate_proto(packages):
    """Describe all parsed ROS message types as a single protobuf file.

    Each ROS type ``pkg/Name`` becomes message ``pkg__Name`` in proto package
    ``apollo``, with fields numbered 1..n in definition order. Field types
    are chosen as follows:

    - Scalar ROS types map through ``basetype_map``.
    - A ``uint8`` array becomes a single optional ``bytes`` field.
    - Any other type is a message reference. A qualified ``pkg/Name`` is
      used directly; an unqualified name resolves to the same package if it
      defines it, otherwise to ``std_msgs`` (e.g. ``Header``).

    Args:
        packages: ``{package: {msg_name: {field: ros_type}}}`` as returned by
            ``read_bag_types``.

    Returns:
        The populated ``FileDescriptorProto``.
    """
    file_desc = FileDescriptorProto()
    file_desc.name = 'apollo'
    file_desc.package = 'apollo'

    for pkg_name, messages in packages.items():
        for msg_name, fields in messages.items():
            message_desc = file_desc.message_type.add()
            message_desc.name = '__'.join((pkg_name, msg_name))
            for number, (field_name, ros_type) in enumerate(fields.items(), 1):
                is_array = '[' in ros_type
                base_type = ros_type.split('[')[0]

                field_desc = message_desc.field.add()
                field_desc.name = field_name
                field_desc.number = number

                if is_array and base_type == 'uint8':
                    # uint8[] travels as one raw bytes blob, not repeated ints.
                    field_desc.label = FieldDescriptorProto.LABEL_OPTIONAL
                    field_desc.type = FieldDescriptorProto.TYPE_BYTES
                    continue

                if is_array:
                    field_desc.label = FieldDescriptorProto.LABEL_REPEATED
                else:
                    field_desc.label = FieldDescriptorProto.LABEL_OPTIONAL

                if base_type in basetype_map:
                    field_desc.type = basetype_map[base_type]
                    continue

                field_desc.type = FieldDescriptorProto.TYPE_MESSAGE
                if '/' in base_type:
                    field_desc.type_name = base_type.replace('/', '__')
                elif base_type in messages:
                    field_desc.type_name = pkg_name + '__' + base_type
                else:
                    field_desc.type_name = 'std_msgs__' + base_type
    return file_desc
def convert_to_rosmsg(packages, msg_classes, ros_msg):
    """Convert a deserialized ROS message into its generated protobuf message.

    Despite the name, the result is a protobuf message: an instance of the
    class generated for ``'apollo.<pkg>__<Msg>'``, not a ROS message.

    Args:
        packages: ``{package: {msg_name: {field: ros_type}}}`` as returned by
            ``read_bag_types``.
        msg_classes: mapping from full protobuf name (``'apollo.<pkg>__<Msg>'``)
            to message class, e.g. the result of ``GetMessages``.
        ros_msg: the ROS message instance to convert. Its ``_type`` must be
            a key path into ``packages``.

    Returns:
        A new protobuf message populated from ``ros_msg``.
    """
    proto_msg = msg_classes['apollo.' + ros_msg._type.replace('/', '__')]()
    _fill_proto(packages, ros_msg, proto_msg)
    return proto_msg


def _fill_proto(packages, ros_msg, proto_msg):
    """Copy every field of ``ros_msg`` into ``proto_msg`` in place, recursively.

    Nested messages are filled directly inside their parent. This avoids
    building a temporary message per level and copying it with CopyFrom.
    """
    msg_package, msg_tp = ros_msg._type.split('/')
    for member, member_type in packages[msg_package][msg_tp].items():
        value = getattr(ros_msg, member)
        if '[' in member_type:
            base_type = member_type.split('[')[0]
            if base_type == 'uint8':
                # The generated proto stores uint8 arrays as a single bytes
                # field. bytes() also accepts a list/tuple of ints, not only a
                # bytes object.
                setattr(proto_msg, member, bytes(value))
            elif base_type in ('time', 'duration'):
                getattr(proto_msg, member).extend([v.to_nsec() for v in value])
            elif base_type in basetype_map:
                getattr(proto_msg, member).extend(value)
            else:
                repeated = getattr(proto_msg, member)
                for elem in value:
                    _fill_proto(packages, elem, repeated.add())
        elif member_type in ('time', 'duration'):
            # Stored as integer nanoseconds.
            setattr(proto_msg, member, value.to_nsec())
        elif member_type in basetype_map:
            setattr(proto_msg, member, value)
        else:
            sub_msg = getattr(proto_msg, member)
            # Mark the sub-message present even if it ends up with no fields.
            sub_msg.SetInParent()
            _fill_proto(packages, value, sub_msg)
def convert_bag_to_record(in_bag, out_record_path, ignore_sensor=False):
    """Convert every message in a rosbag into an Apollo Cyber record.

    Args:
        in_bag: path of the input ROS1 bag.
        out_record_path: path of the Cyber record to write.
        ignore_sensor: when true, topics whose ROS type is in ``sensor_msgs``
            are neither declared as channels nor written.
    """
    packages, topics = read_bag_types(in_bag)
    fd = generate_proto(packages)
    # NOTE(review): fd_str is a serialized FileDescriptorProto. Confirm this is
    # the proto_desc format that RecordWriter.write_channel expects.
    fd_str = fd.SerializeToString()
    message_classes = GetMessages([fd])

    def skipped(ros_type):
        return ignore_sensor and ros_type.startswith('sensor_msgs/')

    out_record = record.RecordWriter()
    out_record.open(out_record_path)
    try:
        for topic, ros_type in topics.items():
            if skipped(ros_type):
                continue
            # Declare the channel with the full name of the generated protobuf
            # message ('apollo.<pkg>__<Msg>', the same key convert_to_rosmsg
            # looks up). A ROS 'pkg/Name' string names no type in fd_str.
            proto_type = 'apollo.' + ros_type.replace('/', '__')
            out_record.write_channel(topic, proto_type, fd_str)
        # The context manager closes the bag even if conversion fails.
        with rosbag.Bag(in_bag) as bag:
            for topic, msg, t in bag.read_messages():
                if skipped(msg._type):
                    continue
                converted = convert_to_rosmsg(packages, message_classes, msg)
                out_record.write_message(topic, converted.SerializeToString(), t.to_nsec())
    finally:
        # Close the record even if conversion fails part-way.
        out_record.close()
def main(argv=None):
    """Command-line entry point: ``IN_BAG OUT_RECORD [--ignore-sensor]``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Convert an entire rosbag to Apollo Cyber Record format.')
    parser.add_argument('in_bag', help='input ROS1 bag file')
    parser.add_argument('out_record', help='output Cyber record file')
    parser.add_argument('--ignore-sensor', action='store_true',
                        help='skip topics whose message type is in sensor_msgs')
    args = parser.parse_args(argv)
    convert_bag_to_record(args.in_bag, args.out_record,
                          ignore_sensor=args.ignore_sensor)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment