@@ -150,20 +150,39 @@ impl LoginDb {
150150 }
151151
152152 pub fn count_by_origin ( & self , origin : & str ) -> Result < i64 > {
153- let mut stmt = self . db . prepare_cached ( & COUNT_BY_ORIGIN_SQL ) ?;
154-
155- let count: i64 = stmt. query_row ( named_params ! { ":origin" : origin } , |row| row. get ( 0 ) ) ?;
156- Ok ( count)
153+ match LoginEntry :: validate_and_fixup_origin ( origin) {
154+ Ok ( result) => {
155+ let origin = result. unwrap_or ( origin. to_string ( ) ) ;
156+ let mut stmt = self . db . prepare_cached ( & COUNT_BY_ORIGIN_SQL ) ?;
157+ let count: i64 =
158+ stmt. query_row ( named_params ! { ":origin" : origin } , |row| row. get ( 0 ) ) ?;
159+ Ok ( count)
160+ }
161+ Err ( e) => {
162+ // don't log the input string as it's PII.
163+ warn ! ( "count_by_origin was passed an invalid origin: {}" , e) ;
164+ Ok ( 0 )
165+ }
166+ }
157167 }
158168
159169 pub fn count_by_form_action_origin ( & self , form_action_origin : & str ) -> Result < i64 > {
160- let mut stmt = self . db . prepare_cached ( & COUNT_BY_FORM_ACTION_ORIGIN_SQL ) ?;
161-
162- let count: i64 = stmt. query_row (
163- named_params ! { ":form_action_origin" : form_action_origin } ,
164- |row| row. get ( 0 ) ,
165- ) ?;
166- Ok ( count)
170+ match LoginEntry :: validate_and_fixup_origin ( form_action_origin) {
171+ Ok ( result) => {
172+ let form_action_origin = result. unwrap_or ( form_action_origin. to_string ( ) ) ;
173+ let mut stmt = self . db . prepare_cached ( & COUNT_BY_FORM_ACTION_ORIGIN_SQL ) ?;
174+ let count: i64 = stmt. query_row (
175+ named_params ! { ":form_action_origin" : form_action_origin } ,
176+ |row| row. get ( 0 ) ,
177+ ) ?;
178+ Ok ( count)
179+ }
180+ Err ( e) => {
181+ // don't log the input string as it's PII.
182+ warn ! ( "count_by_origin was passed an invalid origin: {}" , e) ;
183+ Ok ( 0 )
184+ }
185+ }
167186 }
168187
169188 pub fn get_all ( & self ) -> Result < Vec < EncryptedLogin > > {
@@ -632,6 +651,8 @@ impl LoginDb {
632651 // - Either `form_action_origin` or `http_realm` matches, depending on which one is non-null
633652 //
634653 // This is used for dupe-checking and `find_login_to_update()`
654+ //
655+ // Note that `entry` must be a normalized Login (via `fixup()`)
635656 fn get_by_entry_target ( & self , entry : & LoginEntry ) -> Result < Vec < EncryptedLogin > > {
636657 // Could be lazy_static-ed...
637658 lazy_static:: lazy_static! {
@@ -1213,11 +1234,24 @@ mod tests {
12131234 ..LoginEntry :: default ( )
12141235 } ;
12151236
1237+ let origin_umlaut = "https://bücher.example.com" ;
1238+ let login_umlaut = LoginEntry {
1239+ origin : origin_umlaut. into ( ) ,
1240+ http_realm : Some ( "https://www.example.com" . into ( ) ) ,
1241+ username : "test" . into ( ) ,
1242+ password : "sekret" . into ( ) ,
1243+ ..LoginEntry :: default ( )
1244+ } ;
1245+
12161246 let db = LoginDb :: open_in_memory ( ) ;
1217- db. add_many ( vec ! [ login_a. clone( ) , login_b. clone( ) ] , & * TEST_ENCDEC )
1218- . expect ( "should be able to add logins" ) ;
1247+ db. add_many (
1248+ vec ! [ login_a. clone( ) , login_b. clone( ) , login_umlaut. clone( ) ] ,
1249+ & * TEST_ENCDEC ,
1250+ )
1251+ . expect ( "should be able to add logins" ) ;
12191252
12201253 assert_eq ! ( db. count_by_origin( origin_a) . unwrap( ) , 1 ) ;
1254+ assert_eq ! ( db. count_by_origin( origin_umlaut) . unwrap( ) , 1 ) ;
12211255 }
12221256
12231257 #[ test]
@@ -1243,11 +1277,25 @@ mod tests {
12431277 ..LoginEntry :: default ( )
12441278 } ;
12451279
1280+ let origin_umlaut = "https://bücher.example.com" ;
1281+ let login_umlaut = LoginEntry {
1282+ origin : origin_umlaut. into ( ) ,
1283+ form_action_origin : Some ( origin_umlaut. into ( ) ) ,
1284+ http_realm : Some ( "https://www.example.com" . into ( ) ) ,
1285+ username : "test" . into ( ) ,
1286+ password : "sekret" . into ( ) ,
1287+ ..LoginEntry :: default ( )
1288+ } ;
1289+
12461290 let db = LoginDb :: open_in_memory ( ) ;
1247- db. add_many ( vec ! [ login_a. clone( ) , login_b. clone( ) ] , & * TEST_ENCDEC )
1248- . expect ( "should be able to add logins" ) ;
1291+ db. add_many (
1292+ vec ! [ login_a. clone( ) , login_b. clone( ) , login_umlaut. clone( ) ] ,
1293+ & * TEST_ENCDEC ,
1294+ )
1295+ . expect ( "should be able to add logins" ) ;
12491296
12501297 assert_eq ! ( db. count_by_form_action_origin( origin_a) . unwrap( ) , 1 ) ;
1298+ assert_eq ! ( db. count_by_form_action_origin( origin_umlaut) . unwrap( ) , 1 ) ;
12511299 }
12521300
12531301 #[ test]
@@ -1510,6 +1558,10 @@ mod tests {
15101558 vec ! [ "example.com" ] ,
15111559 vec ! [ "foo.com" ] ,
15121560 ) ;
1561+ }
1562+
1563+ #[ test]
1564+ fn test_get_by_base_domain_punicode ( ) {
15131565 // punycode! This is likely to need adjusting once we normalize
15141566 // on insert.
15151567 check_good_bad (
0 commit comments