Yes, I optimized my network with TensorRT. It is a Mask R-CNN in which only the backbone is optimized, because the last layers cannot be converted. I need more than 1 Hz; the ZED is configured to publish depth at 15 Hz and camera images at 15 Hz. I could set these higher, but I think the bottleneck is somewhere else.
Here is how I load the trained model:
def _load_model(self):
    """Load the trained Mask R-CNN checkpoint and prepare it for inference.

    ``self.weights`` may hold either a dict with a ``model_state`` entry plus
    metadata (``num_classes``, ``klassen``, ``mean``, ``std``, ``img_size``,
    ``jetson_cfg``, ...) or a raw state dict.

    Side effects:
        Sets ``self.klassen`` and ``self.num_classes``, and adds label ids to
        ``self.hose_classes`` / ``self.forbidden_classes``.

    Configuration:
        If the checkpoint carries a non-empty ``jetson_cfg``, the Jetson
        inference profile is applied. Otherwise the ROS ``score_th``
        threshold is used.

    The model is moved to ``self.device``, switched to eval mode and FP16.
    Finally, the backbone is swapped for a cached TensorRT engine when
    possible.

    Returns:
        The model, in eval mode and half precision.

    Raises:
        FileNotFoundError: If ``self.weights`` is empty or does not exist.
    """
    if not self.weights or not os.path.exists(self.weights):
        self.get_logger().error(f'weights_path nicht gefunden: {self.weights}')
        raise FileNotFoundError(self.weights)

    ckpt = torch.load(self.weights, map_location=self.device)
    # Metadata only exists in the dict checkpoint format. Every other format
    # (raw state dict, or anything without .get) gets an empty mapping, so
    # the metadata lookups below cannot raise AttributeError.
    has_meta = isinstance(ckpt, dict) and 'model_state' in ckpt
    meta = ckpt if has_meta else {}

    if has_meta:
        state_dict = ckpt['model_state']
        num_classes = ckpt.get('num_classes', len(LABEL_ID_TO_NAME) + 1)
        klassen = ckpt.get('klassen', None)
        mean = ckpt.get('mean', None)
        std = ckpt.get('std', None)
        self.get_logger().info(
            f'Checkpoint: Epoche {ckpt.get("best_epoch","?")} | Acc {ckpt.get("best_acc", 0):.4f}')
    else:
        state_dict = ckpt
        num_classes = len(LABEL_ID_TO_NAME) + 1
        klassen = None
        mean = std = None

    if not klassen:
        klassen = [LABEL_ID_TO_NAME[i] for i in sorted(LABEL_ID_TO_NAME)]
        self.get_logger().warn('Keine klassen im Checkpoint — Fallback auf class_config.py')
    self.klassen = klassen
    self.num_classes = num_classes

    # Class names map to label ids starting at 1 (id 0 is presumably the
    # background class, consistent with num_classes = len(classes) + 1).
    for label_id, name in enumerate(self.klassen, start=1):
        if name in HOSE_CLASS_NAMES:
            self.hose_classes.add(label_id)
        elif name in FORBIDDEN_CLASS_NAMES:
            self.forbidden_classes.add(label_id)
    if not self.hose_classes:
        self.get_logger().warn(f'Keine Hose-Klassen! HOSE_CLASS_NAMES={HOSE_CLASS_NAMES}')

    model = Modell().get_model(
        'Maskrcnnv2', num_classes, self.device,
        only_last_layers_train=False, version=0, mean=mean, std=std)
    model.load_state_dict(state_dict, strict=True)

    # Camera resolution stored in the checkpoint, assumed to be (width, height).
    # It is also used below to size the TensorRT engine.
    img_w, img_h = meta.get('img_size', (1280, 720))

    # NOTE(review): only the truthiness of jetson_cfg is used. Its contents are
    # ignored, and the profile values below are hard-coded.
    jetson_cfg = meta.get('jetson_cfg', {})
    if jetson_cfg:
        model.roi_heads.score_thresh = 0.5
        model.roi_heads.detections_per_img = 30
        # torchvision's RegionProposalNetwork reads its proposal limits through
        # pre_nms_top_n()/post_nms_top_n(), which index the private dicts
        # _pre_nms_top_n/_post_nms_top_n. Assigning rpn.pre_nms_top_n_test
        # would only create an unused attribute, leaving the 1000/1000
        # defaults in place. So the 'testing' entries are overwritten here.
        model.rpn._pre_nms_top_n = {**model.rpn._pre_nms_top_n, 'testing': 500}
        model.rpn._post_nms_top_n = {**model.rpn._post_nms_top_n, 'testing': 100}
        model.transform.min_size = (img_h,)
        model.transform.max_size = img_w
        # fixed_size is (width, height). It bypasses the min/max rescaling, so
        # the backbone always sees the same (padded) shape, which the static
        # TensorRT engine requires.
        model.transform.fixed_size = (img_w, img_h)
        self.get_logger().info(
            f'Jetson-Modus aktiv: {img_w}×{img_h} score_thresh=0.5 proposals=500/100')
    else:
        # Regular model: the score threshold comes from the ROS parameter.
        model.roi_heads.score_thresh = self.score_th
        self.get_logger().info('Standard-Modell geladen (kein Jetson-Profil)')

    # eval() must come before the TRT probe: in training mode the transform
    # picks a random min_size.
    model.to(self.device).eval().half()
    self._attach_trt_backbone(model, img_w, img_h)
    return model


def _attach_trt_backbone(self, model, img_w, img_h):
    """Replace ``model.backbone`` with a TensorRT engine, built once and cached.

    torch2trt builds an engine for a single static input shape. The dummy
    input must therefore match exactly what the backbone receives after
    ``model.transform`` (resize + padding to a stride multiple). That shape is
    measured by running the model's own transform on a blank
    ``img_h`` x ``img_w`` frame, instead of being hard-coded.

    The shape is encoded in the cache file name, so an engine built for a
    different resolution or profile is never reused. The engine is only
    valid while incoming frames produce that same transformed shape.

    On any failure the PyTorch FP16 backbone is kept and ``self._use_fp16``
    is set to True.
    """
    try:
        from torch2trt import torch2trt, TRTModule

        with torch.no_grad():
            probe = torch.zeros(3, img_h, img_w, device=self.device).half()
            image_list, _ = model.transform([probe])
        _, _, in_h, in_w = image_list.tensors.shape

        # splitext instead of str.replace: works for '.pt' files and never
        # touches '.pth' occurrences elsewhere in the path.
        base, _ = os.path.splitext(self.weights)
        trt_cache = f'{base}_backbone_trt_{in_h}x{in_w}.pth'

        if os.path.exists(trt_cache):
            self.get_logger().info('Lade TRT-Backbone aus Cache...')
            backbone_trt = TRTModule()
            # NOTE(review): newer torch versions default to weights_only=True
            # in torch.load, which may reject the TRTModule state dict. If
            # loading fails silently into the FP16 fallback, check this.
            backbone_trt.load_state_dict(torch.load(trt_cache))
            model.backbone = backbone_trt
            self.get_logger().info('TRT-Backbone geladen.')
        else:
            self.get_logger().info('Konvertiere Backbone → TensorRT (einmalig ~3 Min)...')
            dummy = torch.zeros(1, 3, in_h, in_w, device=self.device).half()
            backbone_trt = torch2trt(
                model.backbone,
                [dummy],
                fp16_mode=True,
                max_workspace_size=1 << 28)
            torch.save(backbone_trt.state_dict(), trt_cache)
            model.backbone = backbone_trt
            self.get_logger().info(f'TRT-Backbone gespeichert: {trt_cache}')
    except Exception as e:
        # Best effort: without TensorRT the FP16 PyTorch backbone stays in use.
        self.get_logger().warn(f'TRT fehlgeschlagen, nutze FP16: {e}')
        self._use_fp16 = True