diff --git a/ddpo_pytorch/stat_tracking.py b/ddpo_pytorch/stat_tracking.py index 4199ab9..ee50034 100644 --- a/ddpo_pytorch/stat_tracking.py +++ b/ddpo_pytorch/stat_tracking.py @@ -9,6 +9,8 @@ class PerPromptStatTracker: self.stats = {} def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) unique = np.unique(prompts) advantages = np.empty_like(rewards) for prompt in unique: