Skip to content

Commit

Permalink
Merge pull request #132 from myyrakle/feat/#15
Browse files Browse the repository at this point in the history
[#15] Request Parameter 자동 바인딩 구현
  • Loading branch information
myyrakle authored Aug 26, 2024
2 parents 2fba901 + 63b8835 commit 44ef9f6
Show file tree
Hide file tree
Showing 3 changed files with 318 additions and 0 deletions.
4 changes: 4 additions & 0 deletions rupring/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,7 @@ impl<T: IModule + Clone + Copy + Sync + Send + 'static> RupringFactory<T> {

#[cfg(test)]
mod test_proc_macro;

pub use anyhow;
pub use serde;
pub use serde_json;
222 changes: 222 additions & 0 deletions rupring/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ pub struct Request {
pub(crate) di_context: Arc<crate::DIContext>,
}

pub trait BindFromRequest {
fn bind(&mut self, request: Request) -> anyhow::Result<Self>
where
Self: Sized;
}

impl UnwindSafe for Request {}

impl Request {
Expand All @@ -21,6 +27,222 @@ impl Request {
}
}

#[derive(Debug, Clone)]
pub struct QueryString(pub Vec<String>);

pub trait QueryStringDeserializer<T>: Sized {
type Error;

fn deserialize_query_string(&self) -> Result<T, Self::Error>;
}

impl<T> QueryStringDeserializer<Option<T>> for QueryString
where
QueryString: QueryStringDeserializer<T>,
{
type Error = ();

fn deserialize_query_string(&self) -> Result<Option<T>, Self::Error> {
let result = Self::deserialize_query_string(self);
match result {
Ok(v) => Ok(Some(v)),
Err(_) => Ok(None),
}
}
}

impl QueryStringDeserializer<i8> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i8, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i8>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i16> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i16, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i16>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<i128> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<i128, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<i128>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<isize> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<isize, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<isize>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u8> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u8, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u8>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u16> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u16, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u16>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<u128> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<u128, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<u128>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<usize> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<usize, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<usize>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<f32> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<f32, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<f32>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<f64> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<f64, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<f64>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<bool> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<bool, Self::Error> {
if let Some(e) = self.0.get(0) {
e.parse::<bool>().map_err(|_| ())
} else {
Err(())
}
}
}

impl QueryStringDeserializer<String> for QueryString {
type Error = ();

fn deserialize_query_string(&self) -> Result<String, Self::Error> {
if let Some(e) = self.0.get(0) {
Ok(e.clone())
} else {
Err(())
}
}
}

#[derive(Debug, Clone)]
pub struct ParamString(pub String);

Expand Down
92 changes: 92 additions & 0 deletions rupring_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,24 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
code += format!(r#"query_parameters: vec![],"#).as_str();
code += "};";

let mut define_struct_for_json = "".to_string();
define_struct_for_json +=
format!(r#"#[derive(rupring::serde::Serialize, rupring::serde::Deserialize)]"#).as_str();
define_struct_for_json += format!(r#"pub struct {struct_name}__JSON {{"#).as_str();

let mut json_field_names = vec![];
let mut path_field_names = vec![];
let mut query_field_names = vec![];

for field in ast.fields.iter() {
let mut description = "".to_string();
let mut example = r#""""#.to_string();

let mut field_name = field.ident.as_ref().unwrap().to_string();
let mut field_type = field.ty.to_token_stream().to_string().replace(" ", "");

let original_field_name = field_name.clone();

let attributes = field.attrs.clone();

let mut is_required = true;
Expand Down Expand Up @@ -754,6 +765,8 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
}

if is_path_parameter {
path_field_names.push((original_field_name.clone(), field_type.clone()));

code += format!(
r#"swagger_definition.path_parameters.push(rupring::swagger::json::SwaggerParameter {{
name: "{field_name}".to_string(),
Expand All @@ -773,6 +786,8 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
}

if is_query_parameter {
query_field_names.push((original_field_name.clone(), field_type.clone()));

code += format!(
r#"swagger_definition.query_parameters.push(rupring::swagger::json::SwaggerParameter {{
name: "{field_name}".to_string(),
Expand All @@ -791,6 +806,15 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
continue;
}

json_field_names.push((original_field_name.clone(), field_type.clone()));

define_struct_for_json += format!(
r#"
pub {original_field_name}: {field_type},
"#
)
.as_str();

// Body 파라미터 생성 구현
code += format!(r#"let property_of_type = {field_type}::to_swagger_definition(context);"#)
.as_str();
Expand Down Expand Up @@ -831,11 +855,79 @@ pub fn derive_rupring_doc(item: TokenStream) -> TokenStream {
.as_str();
}

define_struct_for_json += format!(r#"}}"#).as_str();

code += "rupring::swagger::macros::SwaggerDefinitionNode::Object(swagger_definition)";

code += "}";

code += "}";

code += define_struct_for_json.as_str();

let mut request_bind_code = "".to_string();
request_bind_code +=
format!(r#"impl rupring::request::BindFromRequest for {struct_name} {{"#).as_str();

request_bind_code +=
"fn bind(&mut self, request: rupring::request::Request) -> rupring::anyhow::Result<Self> {";
request_bind_code += "use rupring::request::ParamStringDeserializer;";
request_bind_code += "use rupring::request::QueryStringDeserializer;";

request_bind_code += format!("let mut json_bound = rupring::serde_json::from_str::<{struct_name}__JSON>(request.body.as_str()).unwrap();").as_str();

request_bind_code += format!("let bound = {struct_name} {{").as_str();

for (field_name, _) in json_field_names {
request_bind_code += format!("{field_name}: json_bound.{field_name},").as_str();
}

for (field_name, field_type) in path_field_names {
request_bind_code += format!(
r#"{field_name}: {{
let param = rupring::request::ParamString(
request.path_parameters["{field_name}"].clone()
);
let deserialized: {field_type} = match param.deserialize() {{
Ok(v) => v,
Err(_) => return Err(rupring::anyhow::anyhow!("invalid parameter: {field_name}")),
}};
deserialized
}}
"#
)
.as_str();
}

for (field_name, field_type) in query_field_names {
request_bind_code += format!(
r#"{field_name}: {{
let query = rupring::request::QueryString(
request.query_parameters["{field_name}"].clone()
);
let deserialized: {field_type} = match query.deserialize_query_string() {{
Ok(v) => v,
Err(_) => return Err(rupring::anyhow::anyhow!("invalid parameter: {field_name}")),
}};
deserialized
}},
"#
)
.as_str();
}

request_bind_code += format!("}};").as_str();

request_bind_code += "Ok(bound)";
request_bind_code += "}";

request_bind_code += format!(r#"}}"#).as_str();

code += request_bind_code.as_str();

return TokenStream::from_str(code.as_str()).unwrap();
}

0 comments on commit 44ef9f6

Please sign in to comment.