3535import java .io .File ;
3636import java .io .FileInputStream ;
3737import java .io .FileOutputStream ;
38+ import java .io .IOException ;
39+ import java .io .InputStream ;
40+ import java .io .InvalidClassException ;
3841import java .io .ObjectInputStream ;
3942import java .io .ObjectOutputStream ;
43+ import java .io .ObjectStreamClass ;
44+ import java .io .Serializable ;
4045import java .util .LinkedList ;
4146import java .util .Queue ;
4247import java .util .concurrent .ConcurrentHashMap ;
@@ -63,13 +68,8 @@ private StateSaver() {
6368 * @param context used to get the available cache dir
6469 */
6570 public static void init (final Context context ) {
66- final File externalCacheDir = context .getExternalCacheDir ();
67- if (externalCacheDir != null ) {
68- cacheDirPath = externalCacheDir .getAbsolutePath ();
69- }
70- if (TextUtils .isEmpty (cacheDirPath )) {
71- cacheDirPath = context .getCacheDir ().getAbsolutePath ();
72- }
71+ // Use internal cache directory to prevent other apps from accessing/modifying the state
72+ cacheDirPath = context .getCacheDir ().getAbsolutePath ();
7373 }
7474
7575 /**
@@ -129,7 +129,7 @@ private static SavedState tryToRestore(@NonNull final SavedState savedState,
129129 }
130130
131131 try (FileInputStream fileInputStream = new FileInputStream (file );
132- ObjectInputStream inputStream = new ObjectInputStream (fileInputStream )) {
132+ ObjectInputStream inputStream = new ValidatingObjectInputStream (fileInputStream )) {
133133 //noinspection unchecked
134134 savedObjects = (Queue <Object >) inputStream .readObject ();
135135 }
@@ -310,6 +310,30 @@ public static void clearStateFiles() {
310310 }
311311 }
312312
313+ private static final class ValidatingObjectInputStream extends ObjectInputStream {
314+ ValidatingObjectInputStream (final InputStream in ) throws IOException {
315+ super (in );
316+ }
317+
318+ @ Override
319+ protected Class <?> resolveClass (final ObjectStreamClass desc )
320+ throws IOException , ClassNotFoundException {
321+ final String name = desc .getName ();
322+ if (!isSafe (name )) {
323+ throw new InvalidClassException ("Unauthorized deserialization attempt" , name );
324+ }
325+ return super .resolveClass (desc );
326+ }
327+
328+ private boolean isSafe (final String name ) {
329+ return name .startsWith ("java.lang." )
330+ || name .startsWith ("java.util." )
331+ || name .startsWith ("org.schabi.newpipe." )
332+ || name .startsWith ("[Ljava.lang." )
333+ || name .startsWith ("[Lorg.schabi.newpipe." );
334+ }
335+ }
336+
313337 /**
314338 * Used for describing how to save/read the objects.
315339 * <p>
0 commit comments