-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds support for both the SNOWFLAKE.CORTEX.EMBED_TEXT_768 and SNOWFLAKE.CORTEX.EMBED_TEXT_1024 functions to the Cortex Python SDK.
- Loading branch information
Showing
7 changed files
with
265 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import List, Optional, Union

from snowflake import snowpark
from snowflake.cortex._util import (
    CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
    call_sql_function,
)
from snowflake.ml._internal import telemetry
|
||
|
||
@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText1024(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_1024`` SQL function.

    Args:
        model: The embedding model to use, given either as a string or as a
            Column of strings. The value must name a model accepted by the
            underlying SQL function.
        text: The input text to embed, given either as a string or as a
            Column of strings.
        session: The snowpark session to use. Will be inferred by context if
            not specified.

    Returns:
        Either a list of floats or a Column, whichever ``call_sql_function``
        produces for the given inputs (presumably a single embedding for
        string inputs and a Column of vectors for Column inputs).
    """
    return _embed_text_1024_impl(
        "snowflake.cortex.embed_text_1024", model, text, session=session
    )
|
||
|
||
def _embed_text_1024_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Forward an embedding request to ``call_sql_function``.

    The SQL function name is a parameter rather than a constant so that the
    caller decides which function is invoked.

    Args:
        function: Name of the SQL function to call.
        model: The embedding model, as a string or a Column of strings.
        text: The input text, as a string or a Column of strings.
        session: The snowpark session to pass through; may be None.

    Returns:
        Whatever ``call_sql_function`` returns for these arguments.
    """
    # Note the argument order: the session precedes the SQL function arguments.
    return call_sql_function(function, session, model, text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Optional, Union, List | ||
|
||
from snowflake import snowpark | ||
from snowflake.cortex._util import ( | ||
CORTEX_FUNCTIONS_TELEMETRY_PROJECT, | ||
SnowflakeConfigurationException, | ||
call_sql_function, | ||
) | ||
from snowflake.ml._internal import telemetry | ||
|
||
|
||
@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText768(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_768`` SQL function.

    Args:
        model: The embedding model to use, given either as a string or as a
            Column of strings. The value must name a model accepted by the
            underlying SQL function.
        text: The input text to embed, given either as a string or as a
            Column of strings.
        session: The snowpark session to use. Will be inferred by context if
            not specified.

    Returns:
        Either a list of floats or a Column, whichever ``call_sql_function``
        produces for the given inputs (presumably a single embedding for
        string inputs and a Column of vectors for Column inputs).
    """
    return _embed_text_768_impl(
        "snowflake.cortex.embed_text_768", model, text, session=session
    )
|
||
|
||
def _embed_text_768_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[List[float], snowpark.Column]:
    """Forward an embedding request to ``call_sql_function``.

    The SQL function name is a parameter rather than a constant so that the
    caller decides which function is invoked.

    Args:
        function: Name of the SQL function to call.
        model: The embedding model, as a string or a Column of strings.
        text: The input text, as a string or a Column of strings.
        session: The snowpark session to pass through; may be None.

    Returns:
        Whatever ``call_sql_function`` returns for these arguments.
    """
    # Note the argument order: the session precedes the SQL function arguments.
    return call_sql_function(function, session, model, text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import List | ||
|
||
import _test_util | ||
from absl.testing import absltest | ||
|
||
from snowflake import snowpark | ||
from snowflake.cortex import _embed_text_1024 | ||
from snowflake.snowpark import functions, types | ||
|
||
|
||
class EmbedTest1024Test(absltest.TestCase):
    """Checks ``_embed_text_1024_impl`` against a stub ``embed_text_1024`` UDF."""

    model = "snowflake-arctic-embed-m"
    text = "|text|"

    @staticmethod
    def embed_text_1024_for_test(model: str, text: str) -> List[float]:
        """Stub embedding: ignore both inputs and return 1024 zeros."""
        return [0.0] * 1024

    def setUp(self) -> None:
        """Create a test session and register the stub as a non-permanent UDF."""
        self._session = _test_util.create_test_session()
        functions.udf(
            self.embed_text_1024_for_test,
            name="embed_text_1024",
            session=self._session,
            return_type=types.VectorType(float, 1024),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stub UDF and close the session."""
        try:
            self._session.sql("drop function embed_text_1024(string,string)").collect()
        finally:
            # Close the session even when the drop statement fails, so a
            # cleanup error does not leak the connection.
            self._session.close()

    def test_embed_text_1024_str(self) -> None:
        """String arguments yield the same value the stub returns."""
        res = _embed_text_1024._embed_text_1024_impl(
            "embed_text_1024",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_1024_for_test(self.model, self.text)
        # The message is passed as ``msg``; as a trailing tuple element it
        # would be evaluated and silently discarded.
        self.assertEqual(
            out, res, f"Expected ({type(out)}) {out}, got ({type(res)}) {res}"
        )

    def test_embed_text_1024_column(self) -> None:
        """Column arguments yield the stub's value in the selected column."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_1024._embed_text_1024_impl(
                "embed_text_1024",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        # Single input row, single selected column.
        res = df_out.collect()[0][0]
        out = self.embed_text_1024_for_test(self.model, self.text)

        self.assertEqual(out, res)
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from typing import List | ||
|
||
import _test_util | ||
from absl.testing import absltest | ||
|
||
from snowflake import snowpark | ||
from snowflake.cortex import _embed_text_768 | ||
from snowflake.snowpark import functions, types | ||
|
||
|
||
class EmbedTest768Test(absltest.TestCase):
    """Checks ``_embed_text_768_impl`` against a stub ``embed_text_768`` UDF."""

    model = "snowflake-arctic-embed-m"
    text = "|text|"

    @staticmethod
    def embed_text_768_for_test(model: str, text: str) -> List[float]:
        """Stub embedding: ignore both inputs and return 768 zeros."""
        return [0.0] * 768

    def setUp(self) -> None:
        """Create a test session and register the stub as a non-permanent UDF."""
        self._session = _test_util.create_test_session()
        functions.udf(
            self.embed_text_768_for_test,
            name="embed_text_768",
            session=self._session,
            return_type=types.VectorType(float, 768),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stub UDF and close the session."""
        try:
            self._session.sql("drop function embed_text_768(string,string)").collect()
        finally:
            # Close the session even when the drop statement fails, so a
            # cleanup error does not leak the connection.
            self._session.close()

    def test_embed_text_768_str(self) -> None:
        """String arguments yield the same value the stub returns."""
        res = _embed_text_768._embed_text_768_impl(
            "embed_text_768",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_768_for_test(self.model, self.text)
        # The message is passed as ``msg``; as a trailing tuple element it
        # would be evaluated and silently discarded.
        self.assertEqual(
            out, res, f"Expected ({type(out)}) {out}, got ({type(res)}) {res}"
        )

    def test_embed_text_768_column(self) -> None:
        """Column arguments yield the stub's value in the selected column."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_768._embed_text_768_impl(
                "embed_text_768",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        # Single input row, single selected column.
        res = df_out.collect()[0][0]
        out = self.embed_text_768_for_test(self.model, self.text)

        self.assertEqual(out, res)
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters