Skip to content

Commit

Permalink
Merge pull request #2 from oceanbase/obgeo
Browse files Browse the repository at this point in the history
  • Loading branch information
powerfooI authored Nov 5, 2024
2 parents 509796e + 2f5d03b commit 9c255ae
Show file tree
Hide file tree
Showing 10 changed files with 262 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ poetry install
- install with pip:

```shell
pip install pyobvector==0.1.7
pip install pyobvector==0.1.8
```

## Build Doc
Expand Down
21 changes: 20 additions & 1 deletion pyobvector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,23 @@
* ObSubHashPartition Specify Hash subpartition info
* ObKeyPartition Specify Key partition info
* ObSubKeyPartition Specify Key subpartition info
* ST_GeomFromText GIS function: parse text to geometry object
* st_distance GIS function: calculate distance between Points
* st_dwithin GIS function: check if the distance between two points
* st_astext GIS function: return a Point in human-readable format
"""
from .client import *
from .schema import VECTOR, VectorIndex, OceanBaseDialect, AsyncOceanBaseDialect
from .schema import (
VECTOR,
POINT,
VectorIndex,
OceanBaseDialect,
AsyncOceanBaseDialect,
ST_GeomFromText,
st_distance,
st_dwithin,
st_astext,
)

__all__ = [
"ObVecClient",
Expand All @@ -42,6 +56,7 @@
"IndexParams",
"DataType",
"VECTOR",
"POINT",
"VectorIndex",
"OceanBaseDialect",
"AsyncOceanBaseDialect",
Expand All @@ -58,4 +73,8 @@
"ObSubHashPartition",
"ObKeyPartition",
"ObSubKeyPartition",
"ST_GeomFromText",
"st_distance",
"st_dwithin",
"st_astext",
]
8 changes: 8 additions & 0 deletions pyobvector/client/ob_vec_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
cosine_distance,
inner_product,
negative_inner_product,
ST_GeomFromText,
st_distance,
st_dwithin,
st_astext,
ReplaceStmt,
)
from ..util import ObVersion
Expand Down Expand Up @@ -56,6 +60,10 @@ def __init__(
setattr(func_mod, "cosine_distance", cosine_distance)
setattr(func_mod, "inner_product", inner_product)
setattr(func_mod, "negative_inner_product", negative_inner_product)
setattr(func_mod, "ST_GeomFromText", ST_GeomFromText)
setattr(func_mod, "st_distance", st_distance)
setattr(func_mod, "st_dwithin", st_dwithin)
setattr(func_mod, "st_astext", st_astext)

connection_str = (
f"mysql+oceanbase://{user}:{password}@{uri}/{db_name}?charset=utf8mb4"
Expand Down
11 changes: 11 additions & 0 deletions pyobvector/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@
* inner_product New system function to calculate inner distance between vectors
* negative_inner_product
New system function to calculate neg ip distance between vectors
* ST_GeomFromText GIS function: parse text to geometry object
* st_distance GIS function: calculate distance between Points
* st_dwithin GIS function: check if the distance between two points
* st_astext GIS function: return a Point in human-readable format
* ReplaceStmt Replace into statement based on the extension of SQLAlchemy.Insert
"""
from .vector import VECTOR
from .geo_srid_point import POINT
from .vector_index import VectorIndex, CreateVectorIndex
from .ob_table import ObTable
from .vec_dist_func import l2_distance, cosine_distance, inner_product, negative_inner_product
from .gis_func import ST_GeomFromText, st_distance, st_dwithin, st_astext
from .replace_stmt import ReplaceStmt
from .dialect import OceanBaseDialect, AsyncOceanBaseDialect

__all__ = [
"VECTOR",
"POINT",
"VectorIndex",
"CreateVectorIndex",
"ObTable",
"l2_distance",
"cosine_distance",
"inner_product",
"negative_inner_product",
"ST_GeomFromText",
"st_distance",
"st_dwithin",
"st_astext",
"ReplaceStmt",
"OceanBaseDialect",
"AsyncOceanBaseDialect",
Expand Down
6 changes: 6 additions & 0 deletions pyobvector/schema/dialect.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
"""OceanBase dialect."""
from sqlalchemy import util
from sqlalchemy.dialects.mysql import aiomysql, pymysql

from .reflection import OceanBaseTableDefinitionParser
from .vector import VECTOR
from .geo_srid_point import POINT

class OceanBaseDialect(pymysql.MySQLDialect_pymysql):
# not change dialect name, since it is a subclass of pymysql.MySQLDialect_pymysql
# name = "oceanbase"
"""Ocenbase dialect."""
supports_statement_cache = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ischema_names["VECTOR"] = VECTOR
self.ischema_names["point"] = POINT

@util.memoized_property
def _tabledef_parser(self):
Expand All @@ -29,11 +33,13 @@ def _tabledef_parser(self):


class AsyncOceanBaseDialect(aiomysql.MySQLDialect_aiomysql):
"""OceanBase async dialect."""
supports_statement_cache = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.ischema_names["VECTOR"] = VECTOR
self.ischema_names["point"] = POINT

@util.memoized_property
def _tabledef_parser(self):
Expand Down
38 changes: 38 additions & 0 deletions pyobvector/schema/geo_srid_point.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Point: OceanBase GIS data type for SQLAlchemy"""
from typing import Tuple, Optional
from sqlalchemy.types import UserDefinedType, String

class POINT(UserDefinedType):
"""Point data type definition."""
cache_ok = True
_string = String()

def __init__(
self,
# lat_long: Tuple[float, float],
srid: Optional[int] = None
):
"""Init Latitude and Longitude."""
super(UserDefinedType, self).__init__()
# self.lat_long = lat_long
self.srid = srid

def get_col_spec(self, **kw): # pylint: disable=unused-argument
"""Parse to Point data type definition in text SQL."""
if self.srid is None:
return "POINT"
return f"POINT SRID {self.srid}"

@classmethod
def to_db(cls, value: Tuple[float, float]):
"""Parse tuple to POINT literal"""
return f"POINT({value[0]} {value[1]})"

def bind_processor(self, dialect):
raise ValueError("Never access Point directly.")

def literal_processor(self, dialect):
raise ValueError("Never access Point directly.")

def result_processor(self, dialect, coltype):
raise ValueError("Never access Point directly.")
110 changes: 110 additions & 0 deletions pyobvector/schema/gis_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""gis_func: An extended system function in GIS."""

import logging
from typing import Tuple

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy import BINARY, Float, Boolean, Text

from .geo_srid_point import POINT

logger = logging.getLogger(__name__)

class ST_GeomFromText(FunctionElement):
"""ST_GeomFromText: parse text to geometry object.
Attributes:
type : result type
"""
type = BINARY()

def __init__(self, *args):
super().__init__()
self.args = args

@compiles(ST_GeomFromText)
def compile_ST_GeomFromText(element, compiler, **kwargs): # pylint: disable=unused-argument
"""Compile ST_GeomFromText function."""
args = []
for idx, arg in enumerate(element.args):
if idx == 0:
if (
(not isinstance(arg, Tuple)) or
(len(arg) != 2) or
(not all(isinstance(x, float) for x in arg))
):
raise ValueError(
f"Tuple[float, float] is expected for Point literal," \
f"while get {type(arg)}"
)
args.append(f"'{POINT.to_db(arg)}'")
else:
args.append(str(arg))
args_str = ", ".join(args)
# logger.info(f"{args_str}")
return f"ST_GeomFromText({args_str})"

class st_distance(FunctionElement):
"""st_distance: calculate distance between Points.
Attributes:
type : result type
"""
type = Float()
inherit_cache = True

def __init__(self, *args):
super().__init__()
self.args = args

@compiles(st_distance)
def compile_st_distance(element, compiler, **kwargs): # pylint: disable=unused-argument
"""Compile st_distance function."""
args = ", ".join(compiler.process(arg) for arg in element.args)
return f"st_distance({args})"

class st_dwithin(FunctionElement):
"""st_dwithin: Checks if the distance between two points
is less than a specified distance.
Attributes:
type : result type
"""
type = Boolean()
inherit_cache = True

def __init__(self, *args):
super().__init__()
self.args = args

@compiles(st_dwithin)
def compile_st_dwithin(element, compiler, **kwargs): # pylint: disable=unused-argument
"""Compile st_dwithin function."""
args = []
for idx, arg in enumerate(element.args):
if idx == 2:
args.append(str(arg))
else:
args.append(compiler.process(arg))
args_str = ", ".join(args)
return f"_st_dwithin({args_str})"

class st_astext(FunctionElement):
"""st_astext: Returns a Point in human-readable format.
Attributes:
type : result type
"""
type = Text()
inherit_cache = True

def __init__(self, *args):
super().__init__()
self.args = args

@compiles(st_astext)
def compile_st_astext(element, compiler, **kwargs): # pylint: disable=unused-argument
"""Compile st_astext function."""
args = ", ".join(compiler.process(arg) for arg in element.args)
return f"st_astext({args})"
8 changes: 4 additions & 4 deletions pyobvector/schema/reflection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""OceanBase table definition reflection."""
import re
import logging
from sqlalchemy.dialects.mysql.reflection import MySQLTableDefinitionParser, _re_compile

logger = logging.getLogger(__name__)

class OceanBaseTableDefinitionParser(MySQLTableDefinitionParser):
"""OceanBase table definition parser."""
def __init__(self, dialect, preparer, *, default_schema=None):
MySQLTableDefinitionParser.__init__(self, dialect, preparer)
self.default_schema = default_schema
Expand Down Expand Up @@ -32,7 +34,7 @@ def _prep_regexes(self):

self._re_key = _re_compile(
r" "
r"(?:(VECTOR|(?P<type>\S+)) )?KEY"
r"(?:(SPATIAL|VECTOR|(?P<type>\S+)) )?KEY"
# r"(?:(?P<type>\S+) )?KEY"
r"(?: +{iq}(?P<name>(?:{esc_fq}|[^{fq}])+){fq})?"
r"(?: +USING +(?P<using_pre>\S+))?"
Expand Down Expand Up @@ -66,8 +68,6 @@ def _prep_regexes(self):
def _parse_constraints(self, line):
"""Parse a CONSTRAINT line."""
ret = super()._parse_constraints(line)
# OceanBase show schema/database in foreign key constraint ddl, even if the schema/database is the default one
# logger.info(ret)
if ret:
tp, spec = ret
if tp == "partition":
Expand All @@ -81,4 +81,4 @@ def _parse_constraints(self, line):
spec["onupdate"] = None
if spec.get("ondelete", "").lower() == "restrict":
spec["ondelete"] = None
return ret
return ret
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyobvector"
version = "0.1.7"
version = "0.1.8"
description = "A python SDK for OceanBase Vector Store, based on SQLAlchemy, compatible with Milvus API."
authors = ["shanhaikang.shk <[email protected]>"]
readme = "README.md"
Expand Down
Loading

0 comments on commit 9c255ae

Please sign in to comment.