Skip to content

Commit

Permalink
Add EmbedText support (#115)
Browse files Browse the repository at this point in the history
Adds support for both the SNOWFLAKE.CORTEX.EMBED_TEXT_768 and SNOWFLAKE.CORTEX.EMBED_TEXT_1024 functions to the Cortex Python SDK.
  • Loading branch information
zbloss authored Sep 9, 2024
1 parent 2b044fc commit 0bdaf0b
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 0 deletions.
40 changes: 40 additions & 0 deletions snowflake/cortex/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,44 @@ py_test(
],
)

# Library target for _embed_text_768.py; depends on the shared cortex util
# target and the internal telemetry package.
py_library(
name = "embed_text_768",
srcs = ["_embed_text_768.py"],
deps = [
":util",
"//snowflake/ml/_internal:telemetry",
],
)

# Test target for embed_text_768_test.py; depends on the library under test,
# the shared test utilities, and the connection-parameter helpers.
py_test(
name = "embed_text_768_test",
srcs = ["embed_text_768_test.py"],
deps = [
":embed_text_768",
":test_util",
"//snowflake/ml/utils:connection_params",
],
)

# Library target for _embed_text_1024.py; depends on the shared cortex util
# target and the internal telemetry package.
py_library(
name = "embed_text_1024",
srcs = ["_embed_text_1024.py"],
deps = [
":util",
"//snowflake/ml/_internal:telemetry",
],
)

# Test target for embed_text_1024_test.py; depends on the library under test,
# the shared test utilities, and the connection-parameter helpers.
py_test(
name = "embed_text_1024_test",
srcs = ["embed_text_1024_test.py"],
deps = [
":embed_text_1024",
":test_util",
"//snowflake/ml/utils:connection_params",
],
)

py_library(
name = "init",
srcs = [
Expand All @@ -161,6 +199,8 @@ py_library(
deps = [
":classify_text",
":complete",
":embed_text_768",
":embed_text_1024",
":extract_answer",
":sentiment",
":summarize",
Expand Down
4 changes: 4 additions & 0 deletions snowflake/cortex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from snowflake.cortex._extract_answer import ExtractAnswer
from snowflake.cortex._sentiment import Sentiment
from snowflake.cortex._summarize import Summarize
from snowflake.cortex._embed_text_768 import EmbedText768
from snowflake.cortex._embed_text_1024 import EmbedText1024
from snowflake.cortex._translate import Translate

__all__ = [
"ClassifyText",
"Complete",
"CompleteOptions",
"EmbedText768",
"EmbedText1024",
"ExtractAnswer",
"Sentiment",
"Summarize",
Expand Down
42 changes: 42 additions & 0 deletions snowflake/cortex/_embed_text_1024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List, Optional, Union

from snowflake import snowpark
from snowflake.cortex._util import (
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
call_sql_function,
)
from snowflake.ml._internal import telemetry


@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText1024(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_1024`` SQL function.

    Args:
        model: The model to use for embedding, given either as a string or as
            a Column of strings. The value must name an embedding model that
            the underlying SQL function supports.
        text: The input text, given either as a string or as a Column of
            strings.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        The embedding as a list of floats when plain strings are passed, or a
        Column of vectors containing embeddings when Column arguments are
        passed.
    """
    # ``typing.List`` is used instead of ``list[float]`` so the annotation can
    # be evaluated on interpreters older than Python 3.9.
    return _embed_text_1024_impl(
        "snowflake.cortex.embed_text_1024",
        model,
        text,
        session=session,
    )


def _embed_text_1024_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Invoke the named SQL embedding function with ``model`` and ``text``.

    The SQL function name is a parameter so callers (such as tests) can point
    this at a different function than the production one.

    Args:
        function: Name of the SQL function to call.
        model: The embedding model, as a string or a Column of strings.
        text: The input text, as a string or a Column of strings.
        session: The snowpark session to use; forwarded to ``call_sql_function``.

    Returns:
        Whatever ``call_sql_function`` returns for these arguments: a list of
        floats or a snowpark Column.
    """
    return call_sql_function(function, session, model, text)
43 changes: 43 additions & 0 deletions snowflake/cortex/_embed_text_768.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Optional, Union, List

from snowflake import snowpark
from snowflake.cortex._util import (
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
SnowflakeConfigurationException,
call_sql_function,
)
from snowflake.ml._internal import telemetry


@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText768(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_768`` SQL function.

    Args:
        model: The model to use for embedding, given either as a string or as
            a Column of strings. The value must name an embedding model that
            the underlying SQL function supports.
        text: The input text, given either as a string or as a Column of
            strings.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        The embedding as a list of floats when plain strings are passed, or a
        Column of vectors containing embeddings when Column arguments are
        passed.
    """
    # ``typing.List`` is used instead of ``list[float]`` so the annotation can
    # be evaluated on interpreters older than Python 3.9.
    return _embed_text_768_impl(
        "snowflake.cortex.embed_text_768",
        model,
        text,
        session=session,
    )


def _embed_text_768_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Invoke the named SQL embedding function with ``model`` and ``text``.

    The SQL function name is a parameter so callers (such as tests) can point
    this at a different function than the production one.

    Args:
        function: Name of the SQL function to call.
        model: The embedding model, as a string or a Column of strings.
        text: The input text, as a string or a Column of strings.
        session: The snowpark session to use; forwarded to ``call_sql_function``.

    Returns:
        Whatever ``call_sql_function`` returns for these arguments: a list of
        floats or a snowpark Column.
    """
    return call_sql_function(function, session, model, text)
65 changes: 65 additions & 0 deletions snowflake/cortex/embed_text_1024_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List

import _test_util
from absl.testing import absltest

from snowflake import snowpark
from snowflake.cortex import _embed_text_1024
from snowflake.snowpark import functions, types


class EmbedTest1024Test(absltest.TestCase):
    """Tests ``_embed_text_1024_impl`` against a temporary stand-in UDF.

    Each test registers a non-permanent UDF named ``embed_text_1024`` that
    returns a fixed vector, then checks that the implementation returns that
    vector for both plain-string and Column arguments.
    """

    model = "snowflake-arctic-embed-m"
    text = "|text|"

    @staticmethod
    def embed_text_1024_for_test(model: str, text: str) -> List[float]:
        """Return a 1024-element zero vector regardless of the inputs."""
        return [0.0] * 1024

    def setUp(self) -> None:
        """Open a test session and register the stand-in UDF on it."""
        self._session = _test_util.create_test_session()
        # Registered as a cleanup so the session is closed even when the UDF
        # registration below raises; tearDown is skipped if setUp fails.
        # Cleanups run after tearDown, so the function is dropped first.
        self.addCleanup(self._session.close)
        functions.udf(
            self.embed_text_1024_for_test,
            name="embed_text_1024",
            session=self._session,
            return_type=types.VectorType(float, 1024),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stand-in UDF; the session is closed by the cleanup."""
        self._session.sql("drop function embed_text_1024(string,string)").collect()

    def test_embed_text_1024_str(self) -> None:
        """Plain-string arguments yield the embedding value directly."""
        res = _embed_text_1024._embed_text_1024_impl(
            "embed_text_1024",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_1024_for_test(self.model, self.text)
        # The message is passed as ``msg``; previously it sat after the call
        # as a discarded tuple element and was never shown on failure.
        self.assertEqual(
            out,
            res,
            msg=f"Expected ({type(out)}) {out}, got ({type(res)}) {res}",
        )

    def test_embed_text_1024_column(self) -> None:
        """Column arguments yield a Column usable in a DataFrame select."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_1024._embed_text_1024_impl(
                "embed_text_1024",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        res = df_out.collect()[0][0]
        out = self.embed_text_1024_for_test(self.model, self.text)

        self.assertEqual(out, res)


# Run the absltest runner when this file is executed as a script.
if __name__ == "__main__":
absltest.main()
65 changes: 65 additions & 0 deletions snowflake/cortex/embed_text_768_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List

import _test_util
from absl.testing import absltest

from snowflake import snowpark
from snowflake.cortex import _embed_text_768
from snowflake.snowpark import functions, types


class EmbedTest768Test(absltest.TestCase):
    """Tests ``_embed_text_768_impl`` against a temporary stand-in UDF.

    Each test registers a non-permanent UDF named ``embed_text_768`` that
    returns a fixed vector, then checks that the implementation returns that
    vector for both plain-string and Column arguments.
    """

    model = "snowflake-arctic-embed-m"
    text = "|text|"

    @staticmethod
    def embed_text_768_for_test(model: str, text: str) -> List[float]:
        """Return a 768-element zero vector regardless of the inputs."""
        return [0.0] * 768

    def setUp(self) -> None:
        """Open a test session and register the stand-in UDF on it."""
        self._session = _test_util.create_test_session()
        # Registered as a cleanup so the session is closed even when the UDF
        # registration below raises; tearDown is skipped if setUp fails.
        # Cleanups run after tearDown, so the function is dropped first.
        self.addCleanup(self._session.close)
        functions.udf(
            self.embed_text_768_for_test,
            name="embed_text_768",
            session=self._session,
            return_type=types.VectorType(float, 768),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stand-in UDF; the session is closed by the cleanup."""
        self._session.sql("drop function embed_text_768(string,string)").collect()

    def test_embed_text_768_str(self) -> None:
        """Plain-string arguments yield the embedding value directly."""
        res = _embed_text_768._embed_text_768_impl(
            "embed_text_768",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_768_for_test(self.model, self.text)
        # The message is passed as ``msg``; previously it sat after the call
        # as a discarded tuple element and was never shown on failure.
        self.assertEqual(
            out,
            res,
            msg=f"Expected ({type(out)}) {out}, got ({type(res)}) {res}",
        )

    def test_embed_text_768_column(self) -> None:
        """Column arguments yield a Column usable in a DataFrame select."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_768._embed_text_768_impl(
                "embed_text_768",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        res = df_out.collect()[0][0]
        out = self.embed_text_768_for_test(self.model, self.text)

        self.assertEqual(out, res)


# Run the absltest runner when this file is executed as a script.
if __name__ == "__main__":
absltest.main()
6 changes: 6 additions & 0 deletions snowflake/cortex/package_visibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def test_complete_visible(self) -> None:
def test_extract_answer_visible(self) -> None:
    """ExtractAnswer is reachable on the cortex package and is callable."""
    exported = cortex.ExtractAnswer
    self.assertTrue(callable(exported))

def test_embed_text_768_visible(self) -> None:
    """EmbedText768 is reachable on the cortex package and is callable."""
    exported = cortex.EmbedText768
    self.assertTrue(callable(exported))

def test_embed_text_1024_visible(self) -> None:
    """EmbedText1024 is reachable on the cortex package and is callable."""
    exported = cortex.EmbedText1024
    self.assertTrue(callable(exported))

def test_sentiment_visible(self) -> None:
    """Sentiment is reachable on the cortex package and is callable."""
    exported = cortex.Sentiment
    self.assertTrue(callable(exported))

Expand Down

0 comments on commit 0bdaf0b

Please sign in to comment.