|
24 | 24 | using namespace stim; |
25 | 25 | using namespace stim_pybind; |
26 | 26 |
|
| 27 | + |
27 | 28 | pybind11::object flex_pauli_string_to_unitary_matrix(const stim::FlexPauliString &ps, std::string_view endian) { |
28 | 29 | bool little_endian; |
29 | 30 | if (endian == "little") { |
@@ -323,6 +324,180 @@ pybind11::class_<FlexPauliString> stim_pybind::pybind_pauli_string(pybind11::mod |
323 | 324 | .data()); |
324 | 325 | } |
325 | 326 |
|
| 327 | +static uint8_t convert_pauli_to_int(const pybind11::handle &h) { |
| 328 | + int64_t v = -1; |
| 329 | + if (pybind11::isinstance<pybind11::int_>(h)) { |
| 330 | + try { |
| 331 | + v = pybind11::cast<int64_t>(h); |
| 332 | + } catch (const pybind11::cast_error &) { |
| 333 | + } |
| 334 | + } else if (pybind11::isinstance<pybind11::str>(h)) { |
| 335 | + std::string_view s = pybind11::cast<std::string_view>(h); |
| 336 | + if (s == "I" || s == "_") { |
| 337 | + v = 0; |
| 338 | + } else if (s == "X" || s == "x") { |
| 339 | + v = 1; |
| 340 | + } else if (s == "Y" || s == "y") { |
| 341 | + v = 2; |
| 342 | + } else if (s == "Z" || s == "z") { |
| 343 | + v = 3; |
| 344 | + } |
| 345 | + } |
| 346 | + if (v >= 0 && v < 4) { |
| 347 | + return (uint8_t)v; |
| 348 | + } else { |
| 349 | + throw std::invalid_argument( |
| 350 | + "Don't know how to convert " + pybind11::cast<std::string>(pybind11::repr(h)) + |
| 351 | + " into a pauli.\n" |
| 352 | + "Expected something from {0, 1, 2, 3, 'I', 'X', 'Y', 'Z', '_'}."); |
| 353 | + } |
| 354 | +} |
| 355 | + |
| 356 | +static FlexPauliString pauli_string_from_dict(const pybind11::dict& dict) { |
| 357 | + // Handle empty dict: |
| 358 | + if (dict.empty()) { |
| 359 | + return FlexPauliString(0); |
| 360 | + } |
| 361 | + |
| 362 | + const auto &first_entry = dict.begin(); |
| 363 | + std::vector<std::pair<size_t, uint8_t>> pauli_by_location; |
| 364 | + size_t max_index = 0; |
| 365 | + |
| 366 | + auto add_pauli_to_index = [&pauli_by_location, &max_index](pybind11::handle index, uint8_t pauli) { |
| 367 | + int64_t index_value = pybind11::cast<int64_t>(index); |
| 368 | + if (index_value < 0) { |
| 369 | + throw std::invalid_argument( |
| 370 | + "Qubit index must be non-negative. got: " + std::to_string(index_value)); |
| 371 | + } |
| 372 | + |
| 373 | + size_t index_ = static_cast<size_t>(index_value); |
| 374 | + if (index_ > max_index) { |
| 375 | + max_index = index_; |
| 376 | + } |
| 377 | + pauli_by_location.push_back(std::make_pair(index_, pauli)); |
| 378 | + }; |
| 379 | + |
| 380 | + if (pybind11::isinstance<pybind11::int_>(first_entry->second) || |
| 381 | + pybind11::isinstance<pybind11::str>(first_entry->second)) { |
| 382 | + // Value is int or str -> key is qubit index: |
| 383 | + for (const auto &item : dict) { |
| 384 | + const auto &index = item.first; |
| 385 | + const auto &pauli_string = item.second; |
| 386 | + // Verify index is int for consistency: |
| 387 | + if (!pybind11::isinstance<pybind11::int_>(index)) { |
| 388 | + throw std::invalid_argument( |
| 389 | + "When constructing stim.PauliString from Dict, keys must all be ints (indices) with single Pauli values, or Pauli keys with iterable values. Conflicting key: " + |
| 390 | + pybind11::cast<std::string>(pybind11::repr(index))); |
| 391 | + } |
| 392 | + |
| 393 | + add_pauli_to_index(index, convert_pauli_to_int(pauli_string)); |
| 394 | + } |
| 395 | + |
| 396 | + } else if (pybind11::isinstance<pybind11::iterable>(first_entry->second)) { |
| 397 | + // Value is iterable -> key is Pauli: |
| 398 | + |
| 399 | + // Find maximum number of indices: |
| 400 | + size_t max_expected_indices = 0; |
| 401 | + for (const auto &item : dict) { |
| 402 | + const auto &indices = item.second; |
| 403 | + if (pybind11::isinstance<pybind11::iterable>(indices)) { |
| 404 | + max_expected_indices += static_cast<size_t>(pybind11::len(indices)); |
| 405 | + } else { |
| 406 | + // In the iterable case - all values must also be iterables: |
| 407 | + throw std::invalid_argument( |
| 408 | + "When constructing stim.PauliString from Dict, with values as iterables, all values must be iterables. got: " + |
| 409 | + pybind11::cast<std::string>(pybind11::repr(indices))); |
| 410 | + } |
| 411 | + } |
| 412 | + |
| 413 | + std::unordered_map<size_t, uint8_t> used_indices; |
| 414 | + used_indices.reserve(max_expected_indices); |
| 415 | + |
| 416 | + auto verify_index_not_used = [&used_indices](pybind11::handle index, uint8_t pauli) -> bool { |
| 417 | + // This helper function checks if an index has been used before. It doesn't allow non-trivial |
| 418 | + // Pauli strings to collide, but it does allow collisions with trivial ("I") Pauli strings by not keeping track of them. |
| 419 | + // return true if the new index should be added to the final result. |
| 420 | + |
| 421 | + int64_t index_value = pybind11::cast<int64_t>(index); |
| 422 | + if (index_value < 0) { |
| 423 | + throw std::invalid_argument( |
| 424 | + "Qubit index must be non-negative. got: " + std::to_string(index_value)); |
| 425 | + } |
| 426 | + |
| 427 | + size_t index_size_t = static_cast<size_t>(index_value); |
| 428 | + const auto index_found = used_indices.find(index_size_t); |
| 429 | + |
| 430 | + if (index_found == used_indices.end()) { |
| 431 | + // Index has not been seed yet - add to map if non-trivial: |
| 432 | + if (pauli != 0) { |
| 433 | + used_indices.emplace(index_size_t, pauli); |
| 434 | + } |
| 435 | + } else { |
| 436 | + // Index has been seen before |
| 437 | + if (pauli == 0) { |
| 438 | + // Must not add that Pauli to final result, it will override the older non-trivial one. |
| 439 | + return false; |
| 440 | + } |
| 441 | + |
| 442 | + // Check if older index is not the same Pauli: |
| 443 | + if (index_found->second != pauli) { |
| 444 | + throw std::invalid_argument( |
| 445 | + "More than one Pauli definitions use the same qubit index. Conflict for index:" + |
| 446 | + pybind11::cast<std::string>(pybind11::repr(index))); |
| 447 | + } |
| 448 | + } |
| 449 | + |
| 450 | + // Add new inde to final result: |
| 451 | + return true; |
| 452 | + }; |
| 453 | + |
| 454 | + for (const auto &item : dict) { |
| 455 | + const auto &pauli_str_or_int = item.first; |
| 456 | + const auto &indices = item.second; |
| 457 | + |
| 458 | + // Verify pauli_str_or_int is str or int for consistency: |
| 459 | + if (!(pybind11::isinstance<pybind11::str>(pauli_str_or_int) || |
| 460 | + pybind11::isinstance<pybind11::int_>(pauli_str_or_int))) { |
| 461 | + throw std::invalid_argument( |
| 462 | + "When constructing stim.PauliString from Dict, keys must all be ints (indices) with single Pauli values, or Pauli keys with iterable values. Conflicting key: " + |
| 463 | + pybind11::cast<std::string>(pybind11::repr(pauli_str_or_int))); |
| 464 | + } |
| 465 | + |
| 466 | + for (const auto &qubit_index : indices) { |
| 467 | + // Verify index is an int: |
| 468 | + if (!pybind11::isinstance<pybind11::int_>(qubit_index)) { |
| 469 | + throw std::invalid_argument( |
| 470 | + "Qubit index must be an int. got:" + |
| 471 | + pybind11::cast<std::string>(pybind11::repr(qubit_index))); |
| 472 | + } |
| 473 | + |
| 474 | + uint8_t pauli = convert_pauli_to_int(pauli_str_or_int); |
| 475 | + bool should_add_new_pauli = verify_index_not_used(qubit_index, pauli); |
| 476 | + if (should_add_new_pauli) { |
| 477 | + add_pauli_to_index(qubit_index, pauli); |
| 478 | + } |
| 479 | + } |
| 480 | + } |
| 481 | + } else { |
| 482 | + throw std::invalid_argument( |
| 483 | + "Don't know how to initialize a stim.PauliString from " + |
| 484 | + pybind11::cast<std::string>(pybind11::repr(dict))); |
| 485 | + } |
| 486 | + |
| 487 | + // Format collected info into a FlexPauliString: |
| 488 | + FlexPauliString result(pauli_by_location.empty() ? 0 : max_index+1); |
| 489 | + |
| 490 | + for (const auto &[key, value] : pauli_by_location) { |
| 491 | + // Conver 0-3 to x,z values (00, 01, 10, 11) |
| 492 | + uint8_t p = value; |
| 493 | + p ^= p >> 1; |
| 494 | + result.value.xs[key] = p & 1; |
| 495 | + result.value.zs[key] = (p & 2) >> 1; |
| 496 | + } |
| 497 | + |
| 498 | + return result; |
| 499 | +} |
| 500 | + |
326 | 501 | void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::class_<FlexPauliString> &c) { |
327 | 502 | c.def( |
328 | 503 | pybind11::init( |
@@ -359,36 +534,15 @@ void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::cla |
359 | 534 | return pybind11::cast<FlexPauliString>(other_or); |
360 | 535 | } |
361 | 536 |
|
| 537 | + if (pybind11::isinstance<pybind11::dict>(arg)) { |
| 538 | + return pauli_string_from_dict(pybind11::cast<pybind11::dict>(arg)); |
| 539 | + } |
| 540 | + |
362 | 541 | pybind11::object pauli_indices_or = pybind11::isinstance<pybind11::iterable>(arg) ? arg : pauli_indices; |
363 | 542 | if (!pauli_indices_or.is_none()) { |
364 | 543 | std::vector<uint8_t> ps; |
365 | 544 | for (const pybind11::handle &h : pauli_indices_or) { |
366 | | - int64_t v = -1; |
367 | | - if (pybind11::isinstance<pybind11::int_>(h)) { |
368 | | - try { |
369 | | - v = pybind11::cast<int64_t>(h); |
370 | | - } catch (const pybind11::cast_error &) { |
371 | | - } |
372 | | - } else if (pybind11::isinstance<pybind11::str>(h)) { |
373 | | - std::string_view s = pybind11::cast<std::string_view>(h); |
374 | | - if (s == "I" || s == "_") { |
375 | | - v = 0; |
376 | | - } else if (s == "X" || s == "x") { |
377 | | - v = 1; |
378 | | - } else if (s == "Y" || s == "y") { |
379 | | - v = 2; |
380 | | - } else if (s == "Z" || s == "z") { |
381 | | - v = 3; |
382 | | - } |
383 | | - } |
384 | | - if (v >= 0 && v < 4) { |
385 | | - ps.push_back((uint8_t)v); |
386 | | - } else { |
387 | | - throw std::invalid_argument( |
388 | | - "Don't know how to convert " + pybind11::cast<std::string>(pybind11::repr(h)) + |
389 | | - " into a pauli.\n" |
390 | | - "Expected something from {0, 1, 2, 3, 'I', 'X', 'Y', 'Z', '_'}."); |
391 | | - } |
| 545 | + ps.push_back(convert_pauli_to_int(h)); |
392 | 546 | } |
393 | 547 | FlexPauliString result(ps.size()); |
394 | 548 | for (size_t k = 0; k < ps.size(); k++) { |
@@ -433,6 +587,14 @@ void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::cla |
433 | 587 | Iterable: initializes by interpreting each item as a Pauli. |
434 | 588 | Each item can be a single-qubit Pauli string (like "X"), |
435 | 589 | or an integer. Integers use the convention 0=I, 1=X, 2=Y, 3=Z. |
| 590 | + Dict[int, Union[int, str]]: initializes by interpreting keys as |
| 591 | + the qubit index and values as the Pauli for that index. |
| 592 | + Each value can be a single-qubit Pauli string (like "X"), |
| 593 | + or an integer. Integers use the convention 0=I, 1=X, 2=Y, 3=Z. |
| 594 | + Dict[Union[int, str], Iterable[int]]: initializes by interpreting keys |
| 595 | + as Pauli operators and values as the qubit indices for that Pauli. |
| 596 | + Each key can be a single-qubit Pauli string (like "X"), |
| 597 | + or an integer. Integers use the convention 0=I, 1=X, 2=Y, 3=Z. |
436 | 598 |
|
437 | 599 | Examples: |
438 | 600 | >>> import stim |
@@ -460,6 +622,15 @@ void stim_pybind::pybind_pauli_string_methods(pybind11::module &m, pybind11::cla |
460 | 622 |
|
461 | 623 | >>> stim.PauliString("X6*Y6") |
462 | 624 | stim.PauliString("+i______Z") |
| 625 | +
|
| 626 | + >>> stim.PauliString({0: "X", 2: "Y", 3: "X"}) |
| 627 | + stim.PauliString("+X_YX") |
| 628 | +
|
| 629 | + >>> stim.PauliString({0: "X", 2: 2, 3: 1}) |
| 630 | + stim.PauliString("+X_YX") |
| 631 | +
|
| 632 | + >>> stim.PauliString({"X": [1], 2: [4], "Z": [0, 3]}) |
| 633 | + stim.PauliString("+ZX_ZY") |
463 | 634 | )DOC") |
464 | 635 | .data()); |
465 | 636 |
|
|
0 commit comments